From c60c46758e2dc5a71d93bb9d12b418f077d9c0c4 Mon Sep 17 00:00:00 2001 From: adamkr Date: Fri, 5 Dec 2014 16:09:31 -0800 Subject: [PATCH 1/6] Adding server context for v12 servers along with mock test framework and some tests. --- .../Commands.SqlDatabase.Test.csproj | 16 + .../Database/Cmdlet/DatabaseTestHelper.cs | 5 + .../NewAzureSqlDatabaseServerContextTests.cs | 42 + .../Cmdlet/NewAzureSqlPremiumDatabaseTests.cs | 132 +-- .../Database/Cmdlet/SqlAuthv12MockTests.cs | 140 +++ .../TSql/CustomAttributeProviderExtensions.cs | 87 ++ .../UnitTests/TSql/MockQueryResult.cs | 465 +++++++++ .../UnitTests/TSql/MockSettings.cs | 237 +++++ .../UnitTests/TSql/MockSqlCommand.cs | 625 ++++++++++++ .../UnitTests/TSql/MockSqlConnection.cs | 223 +++++ .../UnitTests/TSql/MockSqlParameter.cs | 115 +++ .../TSql/MockSqlParameterCollection.cs | 269 ++++++ .../TSql/RecordMockDataResultsAttribute.cs | 71 ++ .../UnitTests/UnitTestHelper.cs | 2 + .../Commands.SqlDatabase.csproj | 1 + .../Database/Cmdlet/NewAzureSqlDatabase.cs | 4 +- .../NewAzureSqlDatabaseServerContext.cs | 104 +- .../Server/ServerDataServiceSqlAuth.cs | 2 +- .../Services/Server/TSqlConnectionContext.cs | 886 ++++++++++++++++++ 19 files changed, 3329 insertions(+), 97 deletions(-) create mode 100644 src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs create mode 100644 src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/CustomAttributeProviderExtensions.cs create mode 100644 src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockQueryResult.cs create mode 100644 src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSettings.cs create mode 100644 src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlCommand.cs create mode 100644 src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlConnection.cs create mode 100644 src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlParameter.cs create mode 100644 src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlParameterCollection.cs create mode 100644 src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/RecordMockDataResultsAttribute.cs create mode 100644 src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/Commands.SqlDatabase.Test.csproj b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/Commands.SqlDatabase.Test.csproj index fdcae8c60965..9bb416f5db95 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/Commands.SqlDatabase.Test.csproj +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/Commands.SqlDatabase.Test.csproj @@ -90,6 +90,7 @@ 3.5 + False @@ -132,6 +133,18 @@ + + + + + Component + + + Component + + + + @@ -221,6 +234,9 @@ Designer + + PreserveNewest + diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/DatabaseTestHelper.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/DatabaseTestHelper.cs index eb4ef2a0a4d0..0effb49ad348 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/DatabaseTestHelper.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/DatabaseTestHelper.cs @@ -46,6 +46,11 @@ public static class DatabaseTestHelper /// public static readonly Guid StandardS1SloGuid = new Guid("1b1ebd4d-d903-4baa-97f9-4ea675f5e928"); + /// + /// The unique GUID for identifying the Standard S0 SLO. + /// + public static readonly Guid StandardS0SloGuid = new Guid("f1173c43-91bd-4aaa-973c-54e79e15235b"); + /// /// The unique GUID for identifying the Premium P1 SLO. /// diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlDatabaseServerContextTests.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlDatabaseServerContextTests.cs index 90af9366441c..8e6774c1f32f 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlDatabaseServerContextTests.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlDatabaseServerContextTests.cs @@ -339,5 +339,47 @@ public static void CreateServerContextSqlAuth( contextPsObject.BaseObject is ServerDataServiceSqlAuth, "Expecting a ServerDataServiceSqlAuth object"); } + + /// + /// Common helper method for other tests to create a context for ESA server. + /// + /// The variable name that will hold the new context. + public static void CreateServerContextSqlAuthV2( + System.Management.Automation.PowerShell powershell, + string manageUrl, + string username, + string password, + string contextVariable) + { + UnitTestHelper.ImportAzureModule(powershell); + UnitTestHelper.CreateTestCredential( + powershell, + username, + password); + + Collection serverContext; + using (AsyncExceptionManager exceptionManager = new AsyncExceptionManager()) + { + serverContext = powershell.InvokeBatchScript( + string.Format( + CultureInfo.InvariantCulture, + @"{1} = New-AzureSqlDatabaseServerContext " + + @"-ManageUrl {0} " + + @"-Credential $credential " + + @"-Version 12.0 ", + manageUrl, + contextVariable), + contextVariable); + } + + Assert.AreEqual(0, powershell.Streams.Error.Count, "Errors during run!"); + Assert.AreEqual(0, powershell.Streams.Warning.Count, "Warnings during run!"); + powershell.Streams.ClearStreams(); + + PSObject contextPsObject = serverContext.Single(); + Assert.IsTrue( + contextPsObject.BaseObject is TSqlConnectionContext, + "Expecting a TSqlConnectionContext object"); + } } } diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlPremiumDatabaseTests.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlPremiumDatabaseTests.cs index 0f3b019194ab..dc316ae216b0 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlPremiumDatabaseTests.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlPremiumDatabaseTests.cs @@ -48,7 +48,7 @@ public void CreatePremiumDatabasesWithSqlAuth() "$context"); HttpSession testSession = MockServerHelper.DefaultSessionCollection.GetSession( "UnitTest.Common.CreatePremiumDatabasesWithSqlAuth"); - DatabaseTestHelper.SetDefaultTestSessionSettings(testSession); + DatabaseTestHelper.SetDefaultTestSessionSettings(testSession); testSession.RequestValidator = new Action( (expected, actual) => @@ -57,69 +57,79 @@ public void CreatePremiumDatabasesWithSqlAuth() Assert.AreEqual(expected.RequestInfo.UserAgent, actual.UserAgent); }); - using (AsyncExceptionManager exceptionManager = new AsyncExceptionManager()) + TestCreatePremiumDatabase(powershell, testSession); + } + } + + /// + /// Helper function to create premium database in the powershell environment provided. + /// + /// The powershell environment + /// The test session + private static void TestCreatePremiumDatabase(System.Management.Automation.PowerShell powershell, HttpSession testSession) + { + using (AsyncExceptionManager exceptionManager = new AsyncExceptionManager()) + { + Collection premiumDB_P1, PremiumDB_P2; + using (new MockHttpServer( + exceptionManager, + MockHttpServer.DefaultServerPrefixUri, + testSession)) { - Collection premiumDB_P1, PremiumDB_P2; - using (new MockHttpServer( - exceptionManager, - MockHttpServer.DefaultServerPrefixUri, - testSession)) - { - powershell.InvokeBatchScript( - @"$P1 = Get-AzureSqlDatabaseServiceObjective" + - @" -Context $context" + - @" -ServiceObjectiveName ""P1"""); - - powershell.InvokeBatchScript( - @"$P2 = Get-AzureSqlDatabaseServiceObjective " + - @"-Context $context" + - @" -ServiceObjectiveName ""P2"""); - - premiumDB_P1 = powershell.InvokeBatchScript( - @"$premiumDB_P1 = New-AzureSqlDatabase " + - @"-Context $context " + - @"-DatabaseName NewAzureSqlPremiumDatabaseTests_P1 " + - @"-Edition Premium " + - @"-ServiceObjective $P1 "); - premiumDB_P1 = powershell.InvokeBatchScript("$PremiumDB_P1"); - - powershell.InvokeBatchScript( - @"$PremiumDB_P2 = New-AzureSqlDatabase " + - @"-Context $context " + - @"-DatabaseName NewAzureSqlPremiumDatabaseTests_P2 " + - @"-Collation Japanese_CI_AS " + - @"-Edition Premium " + - @"-ServiceObjective $P2 " + - @"-MaxSizeGB 10 " + - @"-Force"); - PremiumDB_P2 = powershell.InvokeBatchScript("$PremiumDB_P2"); - } + powershell.InvokeBatchScript( + @"$P1 = Get-AzureSqlDatabaseServiceObjective" + + @" -Context $context" + + @" -ServiceObjectiveName ""P1"""); - Assert.AreEqual(0, powershell.Streams.Error.Count, "Errors during run!"); - Assert.AreEqual(0, powershell.Streams.Warning.Count, "Warnings during run!"); - powershell.Streams.ClearStreams(); - - Assert.IsTrue( - premiumDB_P1.Single().BaseObject is Services.Server.Database, - "Expecting a Database object"); - Services.Server.Database databaseP1 = - (Services.Server.Database)premiumDB_P1.Single().BaseObject; - Assert.AreEqual("NewAzureSqlPremiumDatabaseTests_P1", databaseP1.Name, "Expected db name to be NewAzureSqlPremiumDatabaseTests_P1"); - - Assert.IsTrue( - PremiumDB_P2.Single().BaseObject is Services.Server.Database, - "Expecting a Database object"); - Services.Server.Database databaseP2 = - (Services.Server.Database)PremiumDB_P2.Single().BaseObject; - Assert.AreEqual("NewAzureSqlPremiumDatabaseTests_P2", databaseP2.Name, "Expected db name to be NewAzureSqlPremiumDatabaseTests_P2"); - - Assert.AreEqual( - "Japanese_CI_AS", - databaseP2.CollationName, - "Expected collation to be Japanese_CI_AS"); - Assert.AreEqual("Premium", databaseP2.Edition, "Expected edition to be Premium"); - Assert.AreEqual(10, databaseP2.MaxSizeGB, "Expected max size to be 10 GB"); + powershell.InvokeBatchScript( + @"$P2 = Get-AzureSqlDatabaseServiceObjective " + + @"-Context $context" + + @" -ServiceObjectiveName ""P2"""); + + premiumDB_P1 = powershell.InvokeBatchScript( + @"$premiumDB_P1 = New-AzureSqlDatabase " + + @"-Context $context " + + @"-DatabaseName NewAzureSqlPremiumDatabaseTests_P1 " + + @"-Edition Premium " + + @"-ServiceObjective $P1 "); + premiumDB_P1 = powershell.InvokeBatchScript("$PremiumDB_P1"); + + powershell.InvokeBatchScript( + @"$PremiumDB_P2 = New-AzureSqlDatabase " + + @"-Context $context " + + @"-DatabaseName NewAzureSqlPremiumDatabaseTests_P2 " + + @"-Collation Japanese_CI_AS " + + @"-Edition Premium " + + @"-ServiceObjective $P2 " + + @"-MaxSizeGB 10 " + + @"-Force"); + PremiumDB_P2 = powershell.InvokeBatchScript("$PremiumDB_P2"); } + + Assert.AreEqual(0, powershell.Streams.Error.Count, "Errors during run!"); + Assert.AreEqual(0, powershell.Streams.Warning.Count, "Warnings during run!"); + powershell.Streams.ClearStreams(); + + Assert.IsTrue( + premiumDB_P1.Single().BaseObject is Services.Server.Database, + "Expecting a Database object"); + Services.Server.Database databaseP1 = + (Services.Server.Database)premiumDB_P1.Single().BaseObject; + Assert.AreEqual("NewAzureSqlPremiumDatabaseTests_P1", databaseP1.Name, "Expected db name to be NewAzureSqlPremiumDatabaseTests_P1"); + + Assert.IsTrue( + PremiumDB_P2.Single().BaseObject is Services.Server.Database, + "Expecting a Database object"); + Services.Server.Database databaseP2 = + (Services.Server.Database)PremiumDB_P2.Single().BaseObject; + Assert.AreEqual("NewAzureSqlPremiumDatabaseTests_P2", databaseP2.Name, "Expected db name to be NewAzureSqlPremiumDatabaseTests_P2"); + + Assert.AreEqual( + "Japanese_CI_AS", + databaseP2.CollationName, + "Expected collation to be Japanese_CI_AS"); + Assert.AreEqual("Premium", databaseP2.Edition, "Expected edition to be Premium"); + Assert.AreEqual(10, databaseP2.MaxSizeGB, "Expected max size to be 10 GB"); } } diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs new file mode 100644 index 000000000000..97a6ad54c7fa --- /dev/null +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs @@ -0,0 +1,140 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Microsoft.WindowsAzure.Commands.SqlDatabase.Services.Server; +using Microsoft.WindowsAzure.Commands.SqlDatabase.Test.UnitTests.TSql; +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Management.Automation; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.WindowsAzure.Commands.SqlDatabase.Test.UnitTests.Database.Cmdlet +{ + [RecordMockDataResults("./")] + [TestClass] + public class SqlAuthv12MockTests + { + public static string username = "testlogin"; + public static string password = "MyS3curePa$$w0rd"; + public static string manageUrl = "https://mysvr2.adamkr-vm04.onebox.xdb.mscds.com"; + + [TestInitialize] + public void Setup() + { + var mockConn = new MockSqlConnection(); + TSqlConnectionContext.MockSqlConnection = mockConn; + } + + [TestCleanup] + public void Cleanup() + { + // Do any test clean up here. + } + + [TestMethod] + public void NewAzureSqlDatabaseWithSqlAuthv12() + { + + using (System.Management.Automation.PowerShell powershell = + System.Management.Automation.PowerShell.Create()) + { + + // Create a context + NewAzureSqlDatabaseServerContextTests.CreateServerContextSqlAuthV2( + powershell, + manageUrl, + username, + password, + "$context"); + + Collection database1, database2, database3, database4; + + database1 = powershell.InvokeBatchScript( + @"$testdb1 = New-AzureSqlDatabase " + + @"-Context $context " + + @"-DatabaseName testdb1 " + + @"-Force", + @"$testdb1"); + database2 = powershell.InvokeBatchScript( + @"$testdb2 = New-AzureSqlDatabase " + + @"-Context $context " + + @"-DatabaseName testdb2 " + + @"-Collation Japanese_CI_AS " + + @"-Edition Basic " + + @"-MaxSizeGB 2 " + + @"-Force", + @"$testdb2"); + database3 = powershell.InvokeBatchScript( + @"$testdb3 = New-AzureSqlDatabase " + + @"-Context $context " + + @"-DatabaseName testdb3 " + + @"-MaxSizeBytes 107374182400 " + + @"-Force", + @"$testdb3"); + var slo = powershell.InvokeBatchScript( + @"$so = Get-AzureSqlDatabaseServiceObjective " + + @"-Context $context " + + @"-ServiceObjectiveName S2 ", + @"$so"); + database4 = powershell.InvokeBatchScript( + @"$testdb4 = New-AzureSqlDatabase " + + @"-Context $context " + + @"-DatabaseName testdb4 " + + @"-Edition Standard " + + @"-ServiceObjective $so " + + @"-Force", + @"$testdb4"); + + Assert.AreEqual(0, powershell.Streams.Error.Count, "Errors during run!"); + Assert.AreEqual(0, powershell.Streams.Warning.Count, "Warnings during run!"); + powershell.Streams.ClearStreams(); + + Services.Server.Database database = database1.Single().BaseObject as Services.Server.Database; + Assert.IsTrue(database != null, "Expecting a Database object"); + ValidateDatabaseProperties(database, "testdb1", "Standard", 250, 268435456000L, "SQL_Latin1_General_CP1_CI_AS", false, DatabaseTestHelper.StandardS0SloGuid); + + database = database2.Single().BaseObject as Services.Server.Database; + Assert.IsTrue(database != null, "Expecting a Database object"); + ValidateDatabaseProperties(database, "testdb2", "Basic", 2, 2147483648L, "Japanese_CI_AS", false, DatabaseTestHelper.BasicSloGuid); + + database = database3.Single().BaseObject as Services.Server.Database; + Assert.IsTrue(database != null, "Expecting a Database object"); + ValidateDatabaseProperties(database, "testdb3", "Standard", 100, 107374182400L, "SQL_Latin1_General_CP1_CI_AS", false, DatabaseTestHelper.StandardS0SloGuid); + + database = database4.Single().BaseObject as Services.Server.Database; + Assert.IsTrue(database != null, "Expecting a Database object"); + ValidateDatabaseProperties(database, "testdb4", "Standard", 250, 268435456000L, "SQL_Latin1_General_CP1_CI_AS", false, DatabaseTestHelper.StandardS2SloGuid); + } + } + + + /// + /// Validate the properties of a database against the expected values supplied as input. + /// + /// The database object to validate + /// The expected name of the database + /// The expected edition of the database + /// The expected max size of the database in GB + /// The expected Collation of the database + /// Whether or not the database is expected to be a system object. + internal static void ValidateDatabaseProperties( + Services.Server.Database database, + string name, + string edition, + int maxSizeGb, + long maxSizeBytes, + string collation, + bool isSystem, + Guid slo) + { + Assert.AreEqual(name, database.Name); + Assert.AreEqual(edition, database.Edition); + Assert.AreEqual(maxSizeGb, database.MaxSizeGB); + Assert.AreEqual(maxSizeBytes, database.MaxSizeBytes); + Assert.AreEqual(collation, database.CollationName); + Assert.AreEqual(isSystem, database.IsSystemObject); + // Assert.AreEqual(slo, database.ServiceObjectiveId); + } + } +} diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/CustomAttributeProviderExtensions.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/CustomAttributeProviderExtensions.cs new file mode 100644 index 000000000000..1e65d00cb5a6 --- /dev/null +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/CustomAttributeProviderExtensions.cs @@ -0,0 +1,87 @@ +using System; +using System.Reflection; + +namespace Microsoft.SqlServer.Management.Relational.Domain.UnitTest +{ + /// + /// Class that extends ICustomAttributeProvider + /// to allow for type-safe access to custom attributes. + /// + internal static class CustomAttributeProviderExtensions + { + /// + /// Retrieves a list of custom attributes of given type. + /// + /// Type of attribute to retrieve. + /// Object from which to request the custom attributes. + /// Array of attributes + public static T[] GetCustomAttributes(this ICustomAttributeProvider provider) where T : Attribute + { + return GetCustomAttributes(provider, false); + } + + /// + /// Retrieves a list of custom attributes of given type. + /// + /// Type of attribute to retrieve. + /// Object from which to request the custom attributes. + /// Specifies wheather the attributes can be inherited from parent object + /// Array of attributes + public static T[] GetCustomAttributes(this ICustomAttributeProvider provider, bool inherit) where T : Attribute + { + if (provider == null) + { + throw new ArgumentNullException("provider"); + } + + T[] attributes = provider.GetCustomAttributes(typeof(T), inherit) as T[]; + if (attributes == null) + { + return new T[0]; + } + + return attributes; + } + + /// + /// Retrieves a single custom attribute of given type. + /// + /// Type of attribute to retrieve. + /// Object from which to request the custom attributes. + /// An attribute obtained or null. + public static T GetCustomAttribute(this ICustomAttributeProvider provider) where T : Attribute + { + return GetCustomAttribute(provider, false); + } + + /// + /// Retrieves a single custom attribute of given type. + /// + /// Type of attribute to retrieve. + /// Object from which to request the custom attributes. + /// Specifies wheather the attributes can be inherited from parent object + /// An attribute obtained or null. + public static T GetCustomAttribute(this ICustomAttributeProvider provider, bool inherit) where T : Attribute + { + T[] attributes = GetCustomAttributes(provider, inherit); + + if (attributes.Length > 1) + { + throw new InvalidOperationException( + string.Format( + "Domain element is expected to contain 1 attribute(s) of type [{1}], but it contains {0} attribute(s).", + attributes.Length, + typeof(T).Name)); + } + + if (attributes.Length == 1) + { + return attributes[0]; + } + else + { + return null; + } + } + } +} \ No newline at end of file diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockQueryResult.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockQueryResult.cs new file mode 100644 index 000000000000..c4ae181201da --- /dev/null +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockQueryResult.cs @@ -0,0 +1,465 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.SqlClient; +using System.IO; +using System.Reflection; +using System.Runtime.Serialization.Formatters.Binary; +using System.Xml; +using System.Xml.Serialization; + +namespace Microsoft.WindowsAzure.Commands.SqlDatabase.Test.UnitTests.TSql +{ + [Serializable] + public class MockQueryResult + { + /// + /// Mock Query id + /// + private string mockId; + + /// + /// Name of the database being queried + /// + private string databaseName; + + /// + /// The command being executed + /// + private string commandText; + + /// + /// The result if the result is a scalar + /// + private object scalarResult; + + /// + /// The result if the result is a data set + /// + private MockDataSet dataSetResult; + + /// + /// The result if an exception is thrown + /// + private MockException exceptionResult; + + /// + /// Gets or sets the mock id for this query + /// + [XmlElement] + public string MockId + { + get { return this.mockId; } + set { this.mockId = value; } + } + + /// + /// Gets or sets the name of the database being queried + /// + [XmlElement] + public string DatabaseName + { + get { return this.databaseName; } + set { this.databaseName = value; } + } + + /// + /// Gets or sets the command that is being recorded + /// + [XmlElement] + public string CommandText + { + get { return this.commandText; } + set { this.commandText = value; } + } + + /// + /// Gets or sets the scalar result of the query + /// + [XmlElement] + public object ScalarResult + { + get { return this.scalarResult; } + set { this.scalarResult = value; } + } + + /// + /// Gets or sets the data set result of the query + /// + [XmlElement] + public MockDataSet DataSetResult + { + get { return this.dataSetResult; } + set { this.dataSetResult = value; } + } + + /// + /// gets or sets the exception result of the query + /// + [XmlElement] + public MockException ExceptionResult + { + get { return this.exceptionResult; } + set { this.exceptionResult = value; } + } + } + + /// + /// Represents a dataset that was recorded + /// + public class MockDataSet : IXmlSerializable + { + /// + /// The data set + /// + private DataSet dataSet; + + /// + /// Constructor + /// + public MockDataSet() + { + } + + /// + /// Constructor that sets the data set + /// + /// The dataset to initialize this instance with + public MockDataSet(DataSet dataSet) + { + this.dataSet = dataSet; + } + + /// + /// Gets or sets the data set for the mock + /// + public DataSet DataSet + { + get { return this.dataSet; } + set { this.dataSet = value; } + } + + #region IXmlSerializable Members + + /// + /// No schema + /// + /// Always null + public System.Xml.Schema.XmlSchema GetSchema() + { + return null; + } + + /// + /// Helper function to parse the xml into a dataset + /// + /// + public void ReadXml(System.Xml.XmlReader reader) + { + this.dataSet = new DataSet(); + this.dataSet.ReadXml(reader, XmlReadMode.ReadSchema); + } + + /// + /// Helper function to write the dataset to XML + /// + /// + public void WriteXml(System.Xml.XmlWriter writer) + { + this.dataSet.WriteXml(writer, XmlWriteMode.WriteSchema); + } + + #endregion + } + + /// + /// Represents an exception in the mock + /// + public class MockException : IXmlSerializable + { + /// + /// Type key for binary exception data + /// + private const string BinaryExceptionTypeKey = "MockExceptionBinary"; + + /// + /// Type key for sql exception data + /// + private const string SqlExceptionTypeKey = "SqlException"; + + /// + /// The exception + /// + private Exception exception; + + /// + /// C'tor. + /// + public MockException() + { + } + + /// + /// Constructor that initializes the class with an exception + /// + /// The exception data to initialize the instance with + public MockException(Exception exception) + { + this.exception = exception; + } + + /// + /// Gets or sets the exception + /// + public Exception Exception + { + get { return this.exception; } + set { this.exception = value; } + } + + #region IXmlSerializable Members + + /// + /// No schema + /// + /// Always null + public System.Xml.Schema.XmlSchema GetSchema() + { + return null; + } + + /// + /// Helper function to get data from the XmlReader and transform it into a MockException + /// + /// The reader to read from + public void ReadXml(System.Xml.XmlReader reader) + { + reader.ReadStartElement(); + string mockExceptionType = reader.Name; + if (mockExceptionType == BinaryExceptionTypeKey) + { + // Deserialize a binary serialized exception. + reader.ReadStartElement(); + BinaryFormatter formatter = new BinaryFormatter(); + MemoryStream stream = new MemoryStream(System.Convert.FromBase64String(reader.ReadContentAsString())); + this.exception = (Exception)formatter.Deserialize(stream); + reader.ReadEndElement(); + } + else if (mockExceptionType == SqlExceptionTypeKey) + { + // Deserialize a SqlException. + this.exception = DeserializeSqlException(reader); + } + else + { + // Unknown mock exception type + throw new XmlException(string.Format("Unknown mock exception type '{0}' in mock query result.", mockExceptionType)); + } + reader.ReadEndElement(); + } + + /// + /// Helper function to store the MockException in xml + /// + /// + public void WriteXml(System.Xml.XmlWriter writer) + { + Type exceptionType = this.exception.GetType(); + if (exceptionType == typeof(SqlException)) + { + // For SqlExceptions, serialize in text form, for easy viewing/editing + writer.WriteStartElement(SqlExceptionTypeKey); + if (exception.Data.Contains("HelpLink.ProdVer")) + { + writer.WriteStartElement("serverVersion"); + writer.WriteValue(exception.Data["HelpLink.ProdVer"]); + writer.WriteEndElement(); + } + foreach (SqlError error in ((SqlException)exception).Errors) + { + writer.WriteStartElement("SqlError"); + foreach (KeyValuePair pair in new KeyValuePair[]{ + new KeyValuePair("infoNumber", error.Number.ToString()), + new KeyValuePair("errorState", error.State.ToString()), + new KeyValuePair("errorClass", error.Class.ToString()), + new KeyValuePair("server", error.Server), + new KeyValuePair("errorMessage", error.Message), + new KeyValuePair("procedure", error.Procedure), + new KeyValuePair("lineNumber", error.LineNumber.ToString())}) + { + writer.WriteStartElement(pair.Key); + writer.WriteValue(pair.Value); + writer.WriteEndElement(); + } + writer.WriteEndElement(); + } + writer.WriteEndElement(); + } + else if (exceptionType.IsSerializable) + { + // For any other serializable exceptions, use the BinaryFormatter to generate serialize it in binary form, and save it in Xml as Base64. + MemoryStream stream = new MemoryStream(); + BinaryFormatter formatter = new BinaryFormatter(); + formatter.Serialize(stream, this.exception); + string serializedException = System.Convert.ToBase64String(stream.ToArray()); + writer.WriteStartElement(BinaryExceptionTypeKey); + writer.WriteValue(serializedException); + writer.WriteEndElement(); + } + else + { + // Non-Serializable exceptions, nothing can be done at this time + throw new XmlException(string.Format("Unknown mock exception type '{0}' for serialization.", exceptionType.ToString())); + } + } + + #endregion + + #region Deserializer Helpers + + /// + /// Custom helper to deserialize a SqlException object from Xml. + /// + private static SqlException DeserializeSqlException(System.Xml.XmlReader reader) + { + // SqlException constructor takes in two parameters, an errorCollection and a serverVersion. + SqlErrorCollection errorCollection = (SqlErrorCollection)typeof(SqlErrorCollection).GetConstructor(BindingFlags.NonPublic | BindingFlags.Instance, null, System.Type.EmptyTypes, null).Invoke(null); ; + string serverVersion = null; + + // Read the subtree and fill in the parameters. + int startDepth = reader.Depth; + reader.ReadStartElement(); + while (reader.Depth > startDepth) + { + switch (reader.Name) + { + case "serverVersion": + serverVersion = reader.ReadElementContentAsString(); + break; + case "SqlError": + SqlError newSqlError = DeserializeSqlError(reader); + errorCollection.GetType().GetMethod("Add", BindingFlags.NonPublic | BindingFlags.Instance).Invoke(errorCollection, new object[] { newSqlError }); + break; + } + } + reader.ReadEndElement(); + + // Use reflection to create the SqlException. + Type sqlExceptionType = typeof(SqlException); + Type[] types = { typeof(SqlErrorCollection), typeof(String) }; + MethodInfo info = sqlExceptionType.GetMethod("CreateException", BindingFlags.Static | BindingFlags.NonPublic, null, types, null); + return (SqlException)info.Invoke(null, new object[] { errorCollection, serverVersion }); + } + + /// + /// Custom helper to deserialize a SqlError object from Xml. + /// + private static SqlError DeserializeSqlError(System.Xml.XmlReader reader) + { + Dictionary sqlErrorParameters = new Dictionary(); + + // Read the subtree and fill in the parameters. + int startDepth = reader.Depth; + reader.ReadStartElement(); + while (reader.Depth > startDepth) + { + string name = reader.Name; + string value = reader.ReadElementContentAsString(); + sqlErrorParameters.Add(name, value); + } + reader.ReadEndElement(); + // Make sure all parameters were defined. + if ((!sqlErrorParameters.ContainsKey("infoNumber")) || + (!sqlErrorParameters.ContainsKey("errorState")) || + (!sqlErrorParameters.ContainsKey("errorClass")) || + (!sqlErrorParameters.ContainsKey("server")) || + (!sqlErrorParameters.ContainsKey("errorMessage")) || + (!sqlErrorParameters.ContainsKey("procedure")) || + (!sqlErrorParameters.ContainsKey("lineNumber"))) + { + // Incomplete definition + throw new XmlException("Incomplete definition of 'SqlError' in mock query result."); + } + + // Using reflection to create a new SqlError object. + SqlError newSqlError = (SqlError)typeof(SqlError).GetConstructor(BindingFlags.NonPublic | BindingFlags.Instance, null, new Type[] { + typeof(int), typeof(byte), typeof(byte), typeof(string), typeof(string), typeof(string), typeof(int) }, null).Invoke(new object[]{ + int.Parse(sqlErrorParameters["infoNumber"]), + byte.Parse(sqlErrorParameters["errorState"]), + byte.Parse(sqlErrorParameters["errorClass"]), + sqlErrorParameters["server"], + sqlErrorParameters["errorMessage"], + sqlErrorParameters["procedure"], + int.Parse(sqlErrorParameters["lineNumber"])}); + return newSqlError; + } + + #endregion + + } + + /// + /// Represents a mock query result set + /// + [Serializable] + public class MockQueryResultSet + { + /// + /// a list of all the results for the command + /// + private List commandResults = new List(); + + /// + /// Gets or sets the list of all the results for the command + /// + [XmlElement("MockQueryResult")] + public List CommandResults + { + get { return this.commandResults; } + set { this.commandResults = value; } + } + + /// + /// Helper function to deserialize the mock query result set from the stream + /// + /// Stream containing query results + /// An instance of + public static MockQueryResultSet Deserialize(Stream stream) + { + XmlSerializer serializer = new XmlSerializer(typeof(MockQueryResultSet)); + using (StreamReader reader = new StreamReader(stream)) + { + return (MockQueryResultSet)serializer.Deserialize(reader); + } + } + + + /// + /// Serializes the provided into the stream + /// + /// Where to output the serialization + /// What to serialize + public static void Serialize(Stream stream, MockQueryResultSet value) + { + XmlSerializer serializer = new XmlSerializer(typeof(MockQueryResultSet)); + using (StreamWriter writer = new StreamWriter(stream)) + { + serializer.Serialize(writer, value); + } + } + } +} diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSettings.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSettings.cs new file mode 100644 index 000000000000..e710e0ca2e82 --- /dev/null +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSettings.cs @@ -0,0 +1,237 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Data.SqlClient; +using System.Diagnostics; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.WindowsAzure.Commands.SqlDatabase.Test.UnitTests.TSql +{ + internal sealed class MockSettings + { + private string mockId; + + /// + /// Whether we are in recording mode or not + /// + private bool recordingMode; + + ///// + ///// The sql connection string being used + ///// + //private string sqlConnectionString; + + /// + /// Where to save the recordings + /// + private string outputPath; + + /// + /// Whether to use isolated queries + /// + private bool isolatedQueries; + + /// + /// Delegate to be called when initializing a sql connection + /// + private SetupMethodDelegate initializeMethod; + + /// + /// Called when closing a sql connection + /// + private SetupMethodDelegate cleanupMethod; + + /// + /// Called to do test setup + /// + /// + public delegate void SetupMethodDelegate(SqlConnection connection); + + /// + /// Constructor + /// + private MockSettings() + { + } + + /// + /// Gets the mock id + /// + public string MockId + { + get { return this.mockId; } + } + + /// + /// Gets the current recording mode + /// + public bool RecordingMode + { + get { return this.recordingMode; } + } + + /// + /// Gets the output path for the test recordings + /// + public string OutputPath + { + get { return this.outputPath; } + } + + /// + /// Gets whether or not isolated queries are used + /// + public bool IsolatedQueries + { + get { return this.isolatedQueries; } + } + + /// + /// Gets the initialize delegate + /// + public SetupMethodDelegate InitializeMethod + { + get { return this.initializeMethod; } + } + + /// + /// Gets the cleanup delegate + /// + public SetupMethodDelegate CleanupMethod + { + get { return this.cleanupMethod; } + } + + /// + /// Gets all the settings for the mock session + /// + /// + public static MockSettings RetrieveSettings() + { + StackTrace stackTrace = new StackTrace(); + StackFrame[] stackFrames = stackTrace.GetFrames(); + + var testMethodFrames = ( + from StackFrame frame in stackFrames + where (frame.GetMethod().GetCustomAttribute() != null) || + (frame.GetMethod().GetCustomAttribute() != null) || + (frame.GetMethod().GetCustomAttribute() != null) || + (frame.GetMethod().GetCustomAttribute() != null) || + (frame.GetMethod().GetCustomAttribute() != null) + select frame).ToArray(); + + StackFrame testMethodFrame = testMethodFrames.FirstOrDefault(); + MockSettings settings = new MockSettings(); + + if (testMethodFrame != null) + { + settings.mockId = GetMockId(testMethodFrame); + //settings.initializeMethod = FindMockSetupMethod(testMethodFrame); + //settings.cleanupMethod = FindMockSetupMethod(testMethodFrame); + + RecordMockDataResultsAttribute recordAttr = FindRecordMockDataResultsAttribute(testMethodFrame); + if (recordAttr != null) + { + settings.recordingMode = true; + settings.outputPath = recordAttr.OutputPath; + settings.isolatedQueries = recordAttr.IsolatedQueries; + } + } + else + { + // Leave the rest of settings as defaults (nulls and false) + } + + return settings; + } + + private static RecordMockDataResultsAttribute FindRecordMockDataResultsAttribute(StackFrame testMethodFrame) + { + MethodBase testMethod = testMethodFrame.GetMethod(); + + // Try to find RecordMockDataResultsAttribute. + // 1) On the method: + RecordMockDataResultsAttribute recordAttr = testMethod.GetCustomAttribute(); + + // 2) On nearest of the enclosing types. + if (recordAttr == null) + { + for (Type currentType = testMethod.DeclaringType; currentType != null; currentType = currentType.DeclaringType) + { + recordAttr = currentType.GetCustomAttribute(); + + if (recordAttr != null) + break; + } + } + + // 3) On the test assembly + if (recordAttr == null) + { + recordAttr = testMethod.DeclaringType.Assembly.GetCustomAttribute(); + } + + // 4) On the executing assembly + if (recordAttr == null) + { + recordAttr = Assembly.GetExecutingAssembly().GetCustomAttribute(); + } + + return recordAttr; + } + + //private static SetupMethodDelegate FindMockSetupMethod(StackFrame testMethodFrame) + // where T : Attribute + //{ + // Type declaringType = testMethodFrame.GetMethod().DeclaringType; + // MethodInfo[] methods = declaringType.GetMethods(); + + // foreach (MethodInfo method in methods) + // { + // if (method.GetCustomAttribute() != null) + // { + // if (!method.IsStatic) + // { + // throw new NotSupportedException("Non-static mock setup method are not supported."); + // } + + // return delegate(SqlConnection connection) + // { + // method.Invoke(null, new object[] { connection }); + // }; + // } + // } + + // return null; + //} + + private static string GetMockId(StackFrame testMethodFrame) + { + List parts = new List(); + + for (Type type = testMethodFrame.GetMethod().DeclaringType; type != null; type = type.DeclaringType) + { + parts.Insert(0, type.Name); + } + + return String.Join(".", parts.ToArray()); + } + } +} diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlCommand.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlCommand.cs new file mode 100644 index 000000000000..a3ebd597430b --- /dev/null +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlCommand.cs @@ -0,0 +1,625 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Data.SqlClient; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Text.RegularExpressions; +using System.Threading.Tasks; + +namespace Microsoft.WindowsAzure.Commands.SqlDatabase.Test.UnitTests.TSql +{ + /// + /// Mock sql command for recording sessions for test playback + /// + internal class MockSqlCommand : DbCommand + { + /// + /// A dictionary that stores results for both scalar and reader queries. + /// The dictionary maps query text -> list of candidate results. + /// + private static Dictionary> mockResults = new Dictionary>(); + + /// + /// Regular expression to be used in normalizing the query white spaces + /// + private static readonly Regex WhiteSpaceRegex = new Regex(@"\s+"); + + /// + /// The SQL provider creates temp tables with random names to unify the query if + /// it spans multiple databases. The regex below matches the temp table name so that + /// we can replace it with a masked table name. + /// + private static readonly Regex TempTableNameRegex = new Regex(@"\[\#unify_temptbl_[0-9a-fA-F]{8}\]"); + private static readonly string TempTableName = "[#unify_temptbl_XXXXXXXX]"; + + /// + /// Settings for the mock session + /// + private readonly MockSettings settings; + + /// + /// Collection of parameters for the command + /// + private readonly MockSqlParameterCollection parameterCollection; + + /// + /// Static C'tor. Initializes the mock results + /// + static MockSqlCommand() + { + InitializeMockResults(); + } + + /// + /// C'tor. + /// + /// The connection information + /// The mock settings + internal MockSqlCommand(DbConnection connection, MockSettings settings) + { + Assert.IsTrue(connection != null); + Assert.IsTrue(settings != null); + + this.DbConnection = connection; + this.parameterCollection = new MockSqlParameterCollection(); + this.settings = settings; + } + + /// + /// No-op + /// + public override void Cancel() + { + } + + /// + /// Gets or sets the command text + /// + public override string CommandText + { + get; + set; + } + + /// + /// Gets or sets the command timeout + /// + public override int CommandTimeout + { + get; + set; + } + + /// + /// Gets or sets the command type + /// + public override CommandType CommandType + { + get; + set; + } + + /// + /// Returns a new database parameter than can be used + /// + /// A db parameter instance + protected override DbParameter CreateDbParameter() + { + return new MockSqlParameter(); + } + + /// + /// Gets or sets the database connection + /// + protected override DbConnection DbConnection + { + get; + set; + } + + /// + /// Gets the parameter collection + /// + protected override DbParameterCollection DbParameterCollection + { + get + { + return this.parameterCollection; + } + } + + /// + /// Gets or sets the database transaction + /// + protected override DbTransaction DbTransaction + { + get; + set; + } + + /// + /// Gets or sets whether this is design time visible + /// + public override bool DesignTimeVisible + { + get; + set; + } + + /// + /// Executes the database data reader command using mock framework + /// + /// The command behaviour + /// A database data reader + protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) + { + Assert.IsTrue((this.Connection.State & ConnectionState.Open) == ConnectionState.Open, "Connection has to be opened when executing a command"); + + string commandKey = this.GetCommandKey(); + MockQueryResult mockResult = FindMockResult(this.settings.MockId, this.Connection.Database, commandKey, this.settings.IsolatedQueries); + + if (mockResult == null && this.settings.RecordingMode) + { + mockResult = this.RecordExecuteDbDataReader(); + } + + if (mockResult == null || mockResult.DataSetResult == null) + { + if (mockResult != null && mockResult.ExceptionResult != null) + { + throw mockResult.ExceptionResult.Exception; + } + else + { + throw new NotSupportedException(string.Format("Mock SqlConnection does not know how to handle query: '{0}'", commandKey)); + } + } + + return mockResult.DataSetResult.DataSet.CreateDataReader(); + } + + /// + /// Executes a non query command + /// + /// 0 + public override int ExecuteNonQuery() + { + Assert.IsTrue((this.Connection.State & ConnectionState.Open) == ConnectionState.Open, "Connection has to be opened when executing command"); + + if (this.CommandText.StartsWith("USE ", StringComparison.OrdinalIgnoreCase)) + { + string databaseName = this.CommandText.Substring(4).Trim(); + + if (databaseName.StartsWith("[", StringComparison.OrdinalIgnoreCase) && databaseName.EndsWith("]", StringComparison.OrdinalIgnoreCase)) + { + databaseName = databaseName.Substring(1, databaseName.Length - 2).Replace("]]", "]"); + } + + this.Connection.ChangeDatabase(databaseName); + } + + //If recording actually run the command to affect the db + if(this.settings.RecordingMode) + { + using (SqlConnection connection = this.CreateSqlConnection()) + { + connection.Open(); + + SqlCommand cmd = connection.CreateCommand(); + cmd.CommandTimeout = 300; + cmd.CommandType = this.CommandType; + cmd.CommandText = this.CommandText; + + cmd.ExecuteNonQuery(); + } + } + + return 0; + } + + /// + /// Executes a scalar command against the mock framework + /// + /// The scalar result of the query, or null if none exist + public override object ExecuteScalar() + { + Assert.IsTrue((this.Connection.State & ConnectionState.Open) == ConnectionState.Open, "Connection has to be opened when executing command"); + + string commandKey = this.GetCommandKey(); + MockQueryResult mockResult = FindMockResult(this.settings.MockId, this.Connection.Database, commandKey, this.settings.IsolatedQueries); + + if (mockResult == null && this.settings.RecordingMode) + { + mockResult = this.RecordExecuteScalar(); + } + + return mockResult != null ? mockResult.ScalarResult : null; + } + + /// + /// No-op + /// + public override void Prepare() + { + } + + /// + /// Gets or sets the UpdateRowSource + /// + public override UpdateRowSource UpdatedRowSource + { + get; + set; + } + + #region Query Recording + + /// + /// Creates a sql connection + /// + /// + private SqlConnection CreateSqlConnection() + { + //SqlConnectionStringBuilder csb = new SqlConnectionStringBuilder(this.settings.SqlConnectionString); + SqlConnectionStringBuilder csb = new SqlConnectionStringBuilder(this.Connection.ConnectionString); + csb.InitialCatalog = this.Connection.Database; + csb["Encrypt"] = false; + string connectionString = csb.ToString(); + + return new SqlConnection(connectionString); + } + + /// + /// Records the results of calling execute scalar on the command + /// + /// The result of executing the command + private MockQueryResult RecordExecuteScalar() + { + MockQueryResult mockResult = new MockQueryResult(); + + using (SqlConnection connection = this.CreateSqlConnection()) + { + connection.Open(); + + SqlCommand cmd = connection.CreateCommand(); + cmd.CommandTimeout = 120; + cmd.CommandType = this.CommandType; + cmd.CommandText = this.CommandText; + + foreach (DbParameter param in this.Parameters) + { + SqlParameter sqlParam = new SqlParameter(param.ParameterName, param.DbType); + sqlParam.Value = param.Value; + cmd.Parameters.Add(sqlParam); + } + + mockResult.ScalarResult = cmd.ExecuteScalar(); + + mockResult.CommandText = this.GetCommandKey(); + mockResult.DatabaseName = this.DbConnection.Database; + mockResult.MockId = this.settings.MockId; + } + + this.SaveMockResult(mockResult); + + return mockResult; + } + + /// + /// Record the result of calling Execute database data reader + /// + /// The mock query results + private MockQueryResult RecordExecuteDbDataReader() + { + MockQueryResult mockResult = new MockQueryResult(); + + mockResult.CommandText = this.GetCommandKey(); + mockResult.DatabaseName = this.DbConnection.Database; + mockResult.MockId = this.settings.MockId; + + try + { + using (SqlConnection connection = this.CreateSqlConnection()) + { + connection.Open(); + + SqlCommand cmd = connection.CreateCommand(); + cmd.CommandTimeout = 120; + cmd.CommandType = this.CommandType; + cmd.CommandText = this.CommandText; + + foreach (DbParameter param in this.Parameters) + { + SqlParameter sqlParam = new SqlParameter(param.ParameterName, param.DbType); + sqlParam.Value = param.Value; + cmd.Parameters.Add(sqlParam); + } + + DataSet dataSet = new DataSet(); + + SqlDataAdapter adapter = new SqlDataAdapter(cmd); + adapter.Fill(dataSet); + + mockResult.DataSetResult = new MockDataSet(dataSet); + } + + this.SaveMockResult(mockResult); + + return mockResult; + } + catch (SqlException e) + { + // Record any exceptions generated + mockResult.ExceptionResult = new MockException(e); + this.SaveMockResult(mockResult); + + // Rethrow exception to caller + throw; + } + } + + /// + /// Save the mock results to a file. + /// + /// The results to save + private void SaveMockResult(MockQueryResult mockResult) + { + string[] parts = mockResult.MockId.Split(new char[] { '.' }); + + string fileName = Path.Combine(this.settings.OutputPath, parts[0] + ".xml"); + + MockQueryResultSet mockResultSet = null; + if (File.Exists(fileName)) + { + string fileText = File.ReadAllText(fileName).Trim(); + + if (fileText != String.Empty) + { + using (Stream stream = new MemoryStream(System.Text.UnicodeEncoding.UTF8.GetBytes(fileText))) + { + mockResultSet = MockQueryResultSet.Deserialize(stream); + } + } + } + + if (mockResultSet == null) + { + mockResultSet = new MockQueryResultSet(); + } + + string mockResultKey = NormalizeCommandText(mockResult.CommandText); + + int matchIdx = -1; + for (int idx = 0; idx < mockResultSet.CommandResults.Count; idx++) + { + MockQueryResult currentResult = mockResultSet.CommandResults[idx]; + + if (NormalizeCommandText(currentResult.CommandText) == mockResultKey && + currentResult.DatabaseName == mockResult.DatabaseName && + currentResult.MockId == mockResult.MockId) + { + matchIdx = idx; + break; + } + } + + if (matchIdx >= 0) + { + mockResultSet.CommandResults[matchIdx] = mockResult; + } + else + { + mockResultSet.CommandResults.Add(mockResult); + } + + using (Stream stream = File.Open(fileName, FileMode.Create, FileAccess.Write)) + { + MockQueryResultSet.Serialize(stream, mockResultSet); + } + + AddMockResult(mockResult); + } + + /// + /// Get a key based on the command text + /// + /// The command key + private string GetCommandKey() + { + string key = this.CommandText; + + // substitue parameter names by their values + foreach (DbParameter parameter in this.parameterCollection) + { + string value; + + switch (parameter.DbType) + { + case DbType.AnsiString: + case DbType.AnsiStringFixedLength: + value = (string)parameter.Value; + break; + case DbType.String: + case DbType.StringFixedLength: + value = (string)parameter.Value; + break; + case DbType.Boolean: + value = (bool)parameter.Value ? "1" : "0"; + break; + default: + value = parameter.Value.ToString(); + break; + } + + key = key.Replace(parameter.ParameterName, value); + } + + key = key.Replace("\r", string.Empty).Replace("\n", Environment.NewLine); + + key = TempTableNameRegex.Replace(key, TempTableName); + + return key; + } + + #endregion + + #region Query Execution Results + + /// + /// Normalizes the command text by removing excess white spaces + /// + /// + /// + private static string NormalizeCommandText(string commandText) + { + return WhiteSpaceRegex.Replace(commandText, " ").Trim(); + } + + /// + /// Adds a mock result to this list of results. + /// + /// the mock result to add + private static void AddMockResult(MockQueryResult mockResult) + { + List list; + + string key = NormalizeCommandText(mockResult.CommandText); + if (!mockResults.TryGetValue(key, out list)) + { + list = new List(); + mockResults.Add(key, list); + } + + list.Add(mockResult); + } + + /// + /// Find the rank of a match. + /// + /// The id of the candidate matcj + /// The id of the mock to compare. + /// whether or not is an isolated query + /// -1 for bad match, number of parts matched otherwize. + private static int GetMatchRank(string candidateId, string mockId, bool isolatedQuery) + { + if ((candidateId == null) || (mockId == null)) + return isolatedQuery ? -1 : 0; + + string[] mockParts = mockId.Split(new char[] { '.' }); + string[] candidateParts = candidateId.Split(new char[] { '.' }); + + if (candidateParts.Length > mockParts.Length) + { + return -1; + } + + for (int idx = 0; idx < candidateParts.Length; idx++) + { + if (candidateParts[idx] != mockParts[idx]) + { + return -1; + } + } + + return candidateParts.Length; + } + + /// + /// Find the mock results for the given input + /// + /// The mock id of the command + /// The database name being used + /// The command text being used + /// Whether or not it is an isolated query + /// null, or the matching query results + private static MockQueryResult FindMockResult(string mockId, string databaseName, string commandText, bool isolatedQuery) + { + Assert.IsNotNull(databaseName); + Assert.IsNotNull(commandText); + + string key = NormalizeCommandText(commandText); + + if (!mockResults.ContainsKey(key)) + { + return null; + } + + // Find all candidates, with matching mockId and databaseName. + // We prefer the candidate with exact match on mockId in first place and with exact match on databaseName in second place. + var candidates = + (from MockQueryResult mr in mockResults[key] + where (GetMatchRank(mr.MockId, mockId, isolatedQuery) >= 0) && + (mr.DatabaseName == databaseName || mr.DatabaseName == null) + orderby GetMatchRank(mr.MockId, mockId, isolatedQuery) descending, mr.DatabaseName == databaseName descending + select mr).ToArray(); + + // Return the best candidate. + return candidates.Length > 0 ? candidates[0] : null; + } + + /// + /// Used to initialize the mock results + /// + private static void InitializeMockResults() + { + string path = "./TSqlMockSessions"; + + var files = Directory.GetFiles(path); + foreach (var file in files) + { + using (FileStream stream = new FileStream(file, FileMode.Open)) + { + MockQueryResultSet mockResultSet = MockQueryResultSet.Deserialize(stream); + + if (mockResultSet.CommandResults != null) + { + foreach (MockQueryResult mockResult in mockResultSet.CommandResults) + { + AddMockResult(mockResult); + } + } + } + } + + //foreach (string rn in resourceNames) + //{ + // using (Stream stream = Assembly.GetExecutingAssembly().GetManifestResourceStream(rn)) + // { + // if (stream.Length > 0) + // { + // MockQueryResultSet mockResultSet = MockQueryResultSet.Deserialize(stream); + + // if (mockResultSet.CommandResults != null) + // { + // foreach (MockQueryResult mockResult in mockResultSet.CommandResults) + // { + // AddMockResult(mockResult); + // } + // } + // } + // } + //} + } + + #endregion + } +} diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlConnection.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlConnection.cs new file mode 100644 index 000000000000..52b0c93d47c0 --- /dev/null +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlConnection.cs @@ -0,0 +1,223 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.Data.SqlClient; +using System.Data.Common; +using System.Text.RegularExpressions; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Data; + +namespace Microsoft.WindowsAzure.Commands.SqlDatabase.Test.UnitTests.TSql +{ + public class MockSqlConnection : DbConnection + { + private int connectionCount = 0; + private readonly object syncRoot = new object(); + private readonly MockSettings settings; + + /// + /// Constructor for the mock sql connection + /// + public MockSqlConnection() + { + this.settings = MockSettings.RetrieveSettings(); + } + + public static DbConnection CreateConnection(string connectionString) + { + MockSqlConnection conn = new MockSqlConnection(); + SqlConnectionStringBuilder csb = new SqlConnectionStringBuilder(connectionString); + csb["Encrypt"] = false; + conn.ConnectionString = csb.ConnectionString; + return conn; + } + + /// + /// Initializes the mock environment + /// + private void InitializeMockEnvironment() + { + //if (this.settings.RecordingMode && this.settings.InitializeMethod != null) + //{ + // using (SqlConnection connection = new SqlConnection(this.settings.SqlConnectionString)) + // { + // connection.Open(); + + // this.settings.InitializeMethod(connection); + // } + //} + } + + /// + /// Cleans up the mock environment. + /// + private void CleanupMockEnvironment() + { + //if (this.settings.RecordingMode && this.settings.CleanupMethod != null) + //{ + // using (SqlConnection connection = new SqlConnection(this.settings.SqlConnectionString)) + // { + // connection.Open(); + + // this.settings.CleanupMethod(connection); + // } + //} + } + + /// + /// Not supported + /// + /// Not used + /// Not Used + protected override DbTransaction BeginDbTransaction(System.Data.IsolationLevel isolationLevel) + { + throw new NotSupportedException(); + } + + /// + /// Used to change which database is being queried + /// + /// The name of the database to run the queries against + public override void ChangeDatabase(string databaseName) + { + if (!string.IsNullOrEmpty(this.ConnectionString)) + { + SqlConnectionStringBuilder csb = new SqlConnectionStringBuilder(this.ConnectionString); + csb.InitialCatalog = databaseName; + + this.ConnectionString = csb.ToString(); + } + } + + /// + /// Close the connection + /// + public override void Close() + { + lock (this.syncRoot) + { + this.connectionCount--; + + Assert.IsTrue(this.connectionCount >= 0, "Connection has been closed more times than opened. Check for correct pairing of Open/Close methods."); + + if (this.connectionCount == 0) + { + this.CleanupMockEnvironment(); + } + } + } + + /// + /// Gets or sets the sql connection string + /// + public override string ConnectionString + { + get; + set; + } + + /// + /// Creates a DB command for querying the database + /// + /// + protected override DbCommand CreateDbCommand() + { + return new MockSqlCommand(this, this.settings); + } + + /// + /// Gets the data source being queried + /// + public override string DataSource + { + get + { + if (!string.IsNullOrEmpty(this.ConnectionString)) + { + SqlConnectionStringBuilder csb = new SqlConnectionStringBuilder(this.ConnectionString); + return csb.DataSource; + } + else + { + return string.Empty; + } + } + } + + /// + /// Gets the name of the database being queried + /// + public override string Database + { + get + { + string database; + if (!string.IsNullOrEmpty(this.ConnectionString)) + { + SqlConnectionStringBuilder csb = new SqlConnectionStringBuilder(this.ConnectionString); + database = csb.InitialCatalog; + } + else + { + database = null; + } + + return !String.IsNullOrEmpty(database) ? database : "master"; + } + } + + /// + /// Opens a connection + /// + public override void Open() + { + lock (this.syncRoot) + { + if (this.connectionCount == 0) + { + this.InitializeMockEnvironment(); + } + + this.connectionCount++; + } + } + + /// + /// Returns the server version + /// + public override string ServerVersion + { + get + { + return "10.00.1600"; + } + } + + /// + /// Returns the state of the connection + /// + public override System.Data.ConnectionState State + { + get + { + return (0 < this.connectionCount) ? ConnectionState.Open : ConnectionState.Closed; + } + } + } +} diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlParameter.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlParameter.cs new file mode 100644 index 000000000000..253bf0d7724d --- /dev/null +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlParameter.cs @@ -0,0 +1,115 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.WindowsAzure.Commands.SqlDatabase.Test.UnitTests.TSql +{ + internal class MockSqlParameter : DbParameter + { + /// + /// Gets or sets the parameter type + /// + public override DbType DbType + { + get; + set; + } + + /// + /// Gets or sets the parameter direction (input, output, ...) + /// + public override ParameterDirection Direction + { + get; + set; + } + + /// + /// Gets or sets whether the value is nullable + /// + public override bool IsNullable + { + get; + set; + } + + /// + /// Gets or sets the parameter name + /// + public override string ParameterName + { + get; + set; + } + + /// + /// Resets the parameter type (no-op) + /// + public override void ResetDbType() + { + } + + /// + /// Gets or sets the size of the parameter + /// + public override int Size + { + get; + set; + } + + /// + /// Gets or sets the source column + /// + public override string SourceColumn + { + get; + set; + } + + /// + /// Gets or sets the source column null mapping + /// + public override bool SourceColumnNullMapping + { + get; + set; + } + + /// + /// Gets or sets the source version + /// + public override DataRowVersion SourceVersion + { + get; + set; + } + + /// + /// Gets or sets the value of the parameter + /// + public override object Value + { + get; + set; + } + } +} diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlParameterCollection.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlParameterCollection.cs new file mode 100644 index 000000000000..2d3f80d40e81 --- /dev/null +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlParameterCollection.cs @@ -0,0 +1,269 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.WindowsAzure.Commands.SqlDatabase.Test.UnitTests.TSql +{ + /// + /// Represents a collection of parameters for the sql statement + /// + internal class MockSqlParameterCollection : DbParameterCollection + { + /// + /// Internal storage of parameters + /// + private List internalCollection = new List(); + + /// + /// Adds a parameter to the list. Must be of type MockSqlParameterd + /// + /// The parameter value + /// The number of parameters in the collection + public override int Add(object value) + { + this.internalCollection.Add((MockSqlParameter)value); + return (this.internalCollection.Count - 1); + } + + /// + /// Adds a range of values. Must be of type MockSqlParameter + /// + /// An array of parameter values + public override void AddRange(Array values) + { + foreach (MockSqlParameter parameter in values) + { + this.internalCollection.Add(parameter); + } + } + + /// + /// Clears the collection of all parameters + /// + public override void Clear() + { + this.internalCollection.Clear(); + } + + /// + /// Checks to see if a value exists in the collection + /// + /// The value to search for + /// True if the value exists in the collection + public override bool Contains(string value) + { + return (-1 != this.IndexOf(value)); + } + + /// + /// Checks to see if a value exists in the collection + /// + /// The value to search for + /// True if the value exists in the collection + public override bool Contains(object value) + { + return (-1 != this.IndexOf(value)); + } + + /// + /// Copies the contents of array to the internal parameter collection. + /// + /// The parameters to copy + /// Where to insert the array in the internal collection + public override void CopyTo(Array array, int index) + { + this.internalCollection.CopyTo((MockSqlParameter[])array, index); + } + + /// + /// Gets how many parameters are in the collection + /// + public override int Count + { + get { return this.internalCollection.Count(); } + } + + /// + /// Gets an enumerator over the parameter collection + /// + /// + public override System.Collections.IEnumerator GetEnumerator() + { + return this.internalCollection.GetEnumerator(); + } + + /// + /// Get a parameter by name + /// + /// The name of the parameter to retrieve + /// The parameter object or IndexOutOfRangeException + protected override DbParameter GetParameter(string parameterName) + { + int index = this.IndexOf(parameterName); + if (index < 0) + { + throw new IndexOutOfRangeException(); + } + return this.internalCollection[index]; + } + + /// + /// Gets the parameter by index + /// + /// The index of the parameter to retrieve + /// The parameter object + protected override DbParameter GetParameter(int index) + { + return this.internalCollection[index]; + } + + /// + /// Gets the index of a parameter + /// + /// The name of the parameter + /// The index of the parameter with the given name + public override int IndexOf(string parameterName) + { + for (int i = 0; i < this.internalCollection.Count; ++i) + { + if (parameterName == this.internalCollection[i].ParameterName) + { + return i; + } + } + return -1; + } + + /// + /// Gets the index of the parameter + /// + /// The parameter to find the index of + /// The index of the parameter + public override int IndexOf(object value) + { + return this.internalCollection.IndexOf((MockSqlParameter)value); + } + + /// + /// Adds a parameter at a given location + /// + /// Where to insert the parameter + /// The parameter to insert + public override void Insert(int index, object value) + { + this.internalCollection.Insert(index, (MockSqlParameter)value); + } + + /// + /// Gets whether or not the parameter collection is of fixed size + /// + public override bool IsFixedSize + { + get { return false; } + } + + /// + /// Gets whether the parameter collection is fixed size or not + /// + public override bool IsReadOnly + { + get { return false; } + } + + /// + /// Gets whether or not the parameter collection is synchronized + /// + public override bool IsSynchronized + { + get { return false; } + } + + /// + /// Removes a parameter from the collection + /// + /// The parameter to remove + public override void Remove(object value) + { + int index = this.IndexOf(value); + if (index < 0) + { + throw new ArgumentOutOfRangeException(); + } + RemoveAt(index); + } + + /// + /// Remote a parameter with given name + /// + /// The name of the parameter to remove + public override void RemoveAt(string parameterName) + { + int index = this.IndexOf(parameterName); + if (index < 0) + { + throw new ArgumentOutOfRangeException(); + } + RemoveAt(index); + } + + /// + /// Remove a parameter at given index + /// + /// The index of the parameter to remove + public override void RemoveAt(int index) + { + this.internalCollection.RemoveAt(index); + } + + /// + /// Change the value of a parameter. + /// + /// The name of the parameter to change + /// The new value + protected override void SetParameter(string parameterName, DbParameter value) + { + int index = this.IndexOf(parameterName); + if (index < 0) + { + throw new ArgumentOutOfRangeException(); + } + + this.internalCollection[index].Value = value; + } + + /// + /// Change the value of a parameter + /// + /// The index of the parameter to change + /// The new value of the parameter + protected override void SetParameter(int index, DbParameter value) + { + this.internalCollection[index].Value = value; + } + + /// + /// Object for syncronization + /// + public override object SyncRoot + { + get { return null; } + } + } +} diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/RecordMockDataResultsAttribute.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/RecordMockDataResultsAttribute.cs new file mode 100644 index 000000000000..d0c4c72c5411 --- /dev/null +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/RecordMockDataResultsAttribute.cs @@ -0,0 +1,71 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.WindowsAzure.Commands.SqlDatabase.Test.UnitTests.TSql +{ + [AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Assembly)] + public class RecordMockDataResultsAttribute : Attribute + { + private readonly string outputPath; + private readonly bool isolatedQueries; + + /// + /// Creates the RecordMockDataResultsAttribute. + /// + /// The output directory where the captured results will be saved. + public RecordMockDataResultsAttribute(string outputPath) + : this(outputPath, false) + { + } + + /// + /// Creates the RecordMockDataResultsAttribute. + /// + /// The output directory where the captured results will be saved. + /// Name/address of the server. + /// User name. + /// Password + /// Initial database name. + /// Specifies whether the query capture should be in isolated mode. + /// That is shared query results will not be accessible. + public RecordMockDataResultsAttribute(string outputPath, bool isolatedQueries) + { + this.outputPath = outputPath; + this.isolatedQueries = isolatedQueries; + } + + /// + /// The output path for the query results + /// + public string OutputPath + { + get { return this.outputPath; } + } + + /// + /// Gets whether or not the query capture should be in isolated mode. + /// That is shared query results will not be accessible. + /// + public bool IsolatedQueries + { + get { return this.isolatedQueries; } + } + } +} diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/UnitTestHelper.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/UnitTestHelper.cs index 0ec0c4d6c2b1..5129e1576b9f 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/UnitTestHelper.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/UnitTestHelper.cs @@ -270,6 +270,8 @@ public static void CreateTestCredential(System.Management.Automation.PowerShell /// An instance of the object. public static void CreateTestCredential(System.Management.Automation.PowerShell powershell, string username, string password) { + password = password.Replace("$", "`$"); + // Create the test credential powershell.InvokeBatchScript( string.Format(@"$user = ""{0}""", username), diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase/Commands.SqlDatabase.csproj b/src/ServiceManagement/Sql/Commands.SqlDatabase/Commands.SqlDatabase.csproj index 75d1ed611600..e33ac6d52db3 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase/Commands.SqlDatabase.csproj +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase/Commands.SqlDatabase.csproj @@ -190,6 +190,7 @@ + diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabase.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabase.cs index 2c98967d706b..8f2e8c97345f 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabase.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabase.cs @@ -177,7 +177,7 @@ private void ProcessWithServerName(int? maxSizeGb, long? maxSizeBytes) ServerDataServiceCertAuth.Create(this.ServerName, subscription); GetClientRequestId = () => context.ClientRequestId; - + Services.Server.Database response = context.CreateNewDatabase( this.DatabaseName, maxSizeGb, @@ -185,7 +185,7 @@ private void ProcessWithServerName(int? maxSizeGb, long? maxSizeBytes) this.Collation, this.Edition, this.ServiceObjective); - + response = CmdletCommon.WaitForDatabaseOperation(this, context, response, this.DatabaseName, true); // Retrieve the database with the specified name diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabaseServerContext.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabaseServerContext.cs index 01dd933b10eb..f7b5a0ec8abf 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabaseServerContext.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabaseServerContext.cs @@ -140,6 +140,13 @@ public class NewAzureSqlDatabaseServerContext : AzurePSCmdlet HelpMessage = "The subscription to use, or uses the current subscription if not specified")] public string SubscriptionName { get; set; } + /// + /// Switch to indiciate the the server is an ESA server + /// + [Parameter(Mandatory = false, + HelpMessage = "Indicates the server version being targeted. Valid values [2.0, 12.0]. Default = 2.0")] + public float Version { get; set; } + #endregion #region Current Subscription Management @@ -172,43 +179,72 @@ private AzureSubscription CurrentSubscription /// The SQL Authentication credentials for the server. /// A new context, /// or null if an error occurred. - internal ServerDataServiceSqlAuth GetServerDataServiceBySqlAuth( + internal IServerDataServiceContext GetServerDataServiceBySqlAuth( string serverName, Uri managementServiceUri, - SqlAuthenticationCredentials credentials) + SqlAuthenticationCredentials credentials, + Uri manageUrl) { - ServerDataServiceSqlAuth context = null; - + IServerDataServiceContext context = null; Guid sessionActivityId = Guid.NewGuid(); - try + + if (this.MyInvocation.BoundParameters.ContainsKey("Version")) { - context = ServerDataServiceSqlAuth.Create( - managementServiceUri, - sessionActivityId, - credentials, - serverName); - - // Retrieve $metadata to verify model version compatibility - XDocument metadata = context.RetrieveMetadata(); - XDocument filteredMetadata = DataConnectionUtility.FilterMetadataDocument(metadata); - string metadataHash = DataConnectionUtility.GetDocumentHash(filteredMetadata); - if (!context.metadataHashes.Any(knownHash => metadataHash == knownHash)) + if (this.Version == 12.0f) { - this.WriteWarning(Resources.WarningModelOutOfDate); + try + { + context = new TSqlConnectionContext( + sessionActivityId, + manageUrl.Host, + credentials.UserName, + credentials.Password); + } + catch (Exception ex) + { + SqlDatabaseExceptionHandler.WriteErrorDetails( + this, + sessionActivityId.ToString(), + ex); + + // The context is not in an valid state because of the error, set the context + // back to null. + context = null; + } + } + else + { + try + { + context = ServerDataServiceSqlAuth.Create( + managementServiceUri, + sessionActivityId, + credentials, + serverName); + + // Retrieve $metadata to verify model version compatibility + XDocument metadata = ((ServerDataServiceSqlAuth)context).RetrieveMetadata(); + XDocument filteredMetadata = DataConnectionUtility.FilterMetadataDocument(metadata); + string metadataHash = DataConnectionUtility.GetDocumentHash(filteredMetadata); + if (!((ServerDataServiceSqlAuth)context).metadataHashes.Any(knownHash => metadataHash == knownHash)) + { + this.WriteWarning(Resources.WarningModelOutOfDate); + } + + ((ServerDataServiceSqlAuth)context).MergeOption = MergeOption.PreserveChanges; + } + catch (Exception ex) + { + SqlDatabaseExceptionHandler.WriteErrorDetails( + this, + sessionActivityId.ToString(), + ex); + + // The context is not in an valid state because of the error, set the context + // back to null. + context = null; + } } - - context.MergeOption = MergeOption.PreserveChanges; - } - catch (Exception ex) - { - SqlDatabaseExceptionHandler.WriteErrorDetails( - this, - sessionActivityId.ToString(), - ex); - - // The context is not in an valid state because of the error, set the context - // back to null. - context = null; } return context; @@ -250,7 +286,8 @@ internal ServerDataServiceCertAuth GetServerDataServiceByCertAuth( /// A new operation context for the server. internal IServerDataServiceContext CreateServerDataServiceContext( string serverName, - Uri managementServiceUri) + Uri managementServiceUri, + Uri manageUrl) { switch (this.ParameterSetName) { @@ -262,7 +299,8 @@ internal IServerDataServiceContext CreateServerDataServiceContext( return this.GetServerDataServiceBySqlAuth( serverName, managementServiceUri, - credentials); + credentials, + manageUrl); case FullyQualifiedServerNameWithCertAuthParamSet: case ServerNameWithCertAuthParamSet: @@ -293,7 +331,7 @@ public override void ExecuteCmdlet() // Creates a new Server Data Service Context for the service IServerDataServiceContext operationContext = - this.CreateServerDataServiceContext(serverName, managementServiceUri); + this.CreateServerDataServiceContext(serverName, managementServiceUri, manageUrl); if (operationContext != null) { diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/ServerDataServiceSqlAuth.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/ServerDataServiceSqlAuth.cs index 4b321b26e1bf..a4cd18328169 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/ServerDataServiceSqlAuth.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/ServerDataServiceSqlAuth.cs @@ -338,7 +338,7 @@ public Database CreateNewDatabase( { database.MaxSizeGB = (int)databaseMaxSizeGb; } - if(databaseMaxSizeBytes != null) + if (databaseMaxSizeBytes != null) { database.MaxSizeBytes = (long)databaseMaxSizeBytes; } diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs new file mode 100644 index 000000000000..b7cb652ad20f --- /dev/null +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs @@ -0,0 +1,886 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Data.SqlClient; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.WindowsAzure.Commands.SqlDatabase.Services.Server +{ + public class TSqlConnectionContext : IServerDataServiceContext + { + /// + /// Timeout duration for commands + /// + private static int connectionTimeout = 60; + + /// + /// Set this to override the SQL Connection with a mock version + /// + public static object MockSqlConnection = null; + + /// + /// Query for retrieving database info + /// + private const string getDatabaseQuery = @" + SELECT + [db].[name], + [db].[database_id], + [db].[create_date], + [db].[collation_name], + [db].[is_read_only], + [db].[is_query_store_on], + [db].[is_recursive_triggers_on], + [db].[is_federation_member], + CONVERT (bit, CASE WHEN [db].[name] in ('master') THEN 1 ELSE [db].[is_distributor] END) AS [is_system_object], + CONVERT (int, + CASE + WHEN [db].[is_in_standby] = 1 THEN 0x0040 + ELSE 0 + END | + CASE + WHEN [db].[is_cleanly_shutdown] = 1 THEN 0x0080 + ELSE 0 + END | + CASE [db].[state] + WHEN 0 THEN 0x0001 -- NORMAL + WHEN 1 THEN 0x0002 -- RESTORING + WHEN 2 THEN 0x0008 -- RECOVERING + WHEN 3 THEN 0x0004 -- RECOVERY_PENDING + WHEN 4 THEN 0x0010 -- SUSPECT + WHEN 5 THEN 0x0100 -- EMERGENCY + WHEN 6 THEN 0x0020 -- OFFLINE + WHEN 7 THEN 0x0400 -- COPYING + WHEN 9 THEN 0x0800 -- CREATING + WHEN 10 THEN 0x1000 -- OFFLINE_SECONDARY + ELSE 0x0010 -- SUSPECT + END) AS [status] + FROM [sys].[databases] AS [db] + WHERE ([db].[name] = @name OR @name IS NULL)"; + + /// + /// Gets the session activity Id associated with this context. + /// + public Guid SessionActivityId + { + get + { + return this.sessionActivityId; + } + } + + /// + /// Gets the client per session tracing Id. + /// + public string ClientSessionId + { + get + { + return SqlDatabaseCmdletBase.clientSessionId; + } + } + + /// + /// Gets the previous request's client request Id. + /// + public string ClientRequestId + { + get + { + return this.clientRequestId; + } + } + + /// + /// Gets the name of the server for this context. + /// + public string ServerName + { + get + { + return this.serverName; + } + } + + /// + /// Contains the connection string necessary to connect to the server + /// + private SqlConnectionStringBuilder builder; + + /// + /// Unique session ID + /// + private Guid sessionActivityId; + + /// + /// Unique client request ID + /// + private string clientRequestId; + + /// + /// Server name for the context + /// + private string serverName; + + /// + /// Helper function to generate the SqlConnectionStringBuilder + /// + /// The fully qualified server name (eg: server1.database.windows.net) + /// The login username + /// The login password + /// A connection string builder + private SqlConnectionStringBuilder GenerateSqlConnectionBuilder(string fullyQualifiedServerName, string username, string password) + { + this.serverName = fullyQualifiedServerName.Split('.').First(); + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(); + builder["Server"] = fullyQualifiedServerName; + builder.UserID = username + "@" + serverName; + builder.Password = password; + builder["Database"] = null; + builder["Encrypt"] = true; + builder.ConnectTimeout = connectionTimeout; + + return builder; + } + + /// + /// Creates a sql connection using the sql connection string builder. + /// + /// + private DbConnection CreateConnection() + { + if (MockSqlConnection != null) + { + ((DbConnection)MockSqlConnection).ConnectionString = builder.ConnectionString; + return (DbConnection)MockSqlConnection; + } + else + { + return new SqlConnection(builder.ConnectionString); + } + } + + /// + /// Creates an instance of a SQLAuth to TSql class + /// + /// + /// + /// + public TSqlConnectionContext(Guid sessionActivityId, string fullyQualifiedServerName, string username, string password) + { + this.sessionActivityId = sessionActivityId; + this.clientRequestId = SqlDatabaseCmdletBase.GenerateClientTracingId(); + builder = GenerateSqlConnectionBuilder(fullyQualifiedServerName, username, password); + } + + /// + /// Retrieves the list of all databases on the server. + /// + /// An array of all databases on the server. + public Database[] GetDatabases() + { + List databases = new List(); + + builder["Database"] = null; + using (var connection = CreateConnection()) + { + using (DbCommand command = connection.CreateCommand()) + { + command.CommandTimeout = connectionTimeout; + command.CommandText = getDatabaseQuery; + DbParameter param = command.CreateParameter(); + param.ParameterName = "@name"; + param.Value = DBNull.Value; + command.Parameters.Add(param); + + connection.Open(); + + using (DbDataReader reader = command.ExecuteReader()) + { + if (reader.HasRows) + { + while (reader.Read()) + { + databases.Add(PopulateDatabaseFromReader(reader)); + } + } + else + { + return null; + } + } + } + } + + databases.ForEach(db => GetDatabaseProperties(db)); + databases.ForEach(db => db.ServiceObjective = GetServiceObjective(db.ServiceObjectiveName)); + + databases.ForEach(db => db.Context = this); + + return databases.ToArray(); + } + + /// + /// Retrieve information on database with the name . + /// + /// The database to retrieve. + /// An object containing the information about the specific database. + public Database GetDatabase(string databaseName) + { + Database db = null; + + builder["Database"] = null; + using (var connection = CreateConnection()) + { + using (DbCommand command = connection.CreateCommand()) + //using (SqlCommand command = new SqlCommand(getDatabaseQuery, connection)) + { + command.CommandTimeout = connectionTimeout; + command.CommandText = getDatabaseQuery; + DbParameter param = command.CreateParameter(); + param.ParameterName = "@name"; + param.Value = databaseName; + command.Parameters.Add(param); + + connection.Open(); + + //using (SqlDataReader reader = command.ExecuteReader()) + using (DbDataReader reader = command.ExecuteReader()) + { + if (reader.HasRows) + { + while (reader.Read()) + { + db = PopulateDatabaseFromReader(reader); + } + } + else + { + return null; + } + } + } + } + + //db.ServiceObjectiveName; + //db.serviceobjectiveid + //db.SLOAssignment* + + //db.Server; + //db.ServiceObjective; + //db.DatabaseCopies; + //db.DatabaseMetrics; + + GetDatabaseProperties(db); + db.ServiceObjective = GetServiceObjective(db.ServiceObjectiveName); + db.Context = this; + + return db; + } + + /// + /// Creates a new Sql Database. + /// + /// The name for the new database. + /// The max size for the database in GB. + /// The max size for the database in bytes. + /// The collation for the database. + /// The edition for the database. + /// The service object to assign to the database + /// The newly created Sql Database. + public Database CreateNewDatabase( + string databaseName, + int? databaseMaxSizeGb, + long? databaseMaxSizeBytes, + string databaseCollation, + DatabaseEdition databaseEdition, + ServiceObjective serviceObjective) + { + return CreateNewDatabase( + databaseName, + databaseMaxSizeGb, + databaseMaxSizeBytes, + databaseCollation, + databaseEdition, + serviceObjective == null ? null : serviceObjective.Name); + } + + /// + /// Creates a new Sql Database. + /// + /// The name for the new database. + /// The max size for the database in GB. + /// The max size for the database in bytes. + /// The collation for the database. + /// The edition for the database. + /// The service object to assign to the database + /// The newly created Sql Database. + public Database CreateNewDatabase( + string databaseName, + int? databaseMaxSizeGb, + long? databaseMaxSizeBytes, + string databaseCollation, + DatabaseEdition databaseEdition, + string serviceObjectiveName) + { + builder["Database"] = null; + + string commandText = "CREATE DATABASE [{0}] "; + + if (!string.IsNullOrEmpty(databaseCollation)) + { + commandText += " COLLATE {1} "; + } + + List arguments = new List(); + + if (databaseMaxSizeGb != null || databaseMaxSizeBytes != null) + { + arguments.Add(" MAXSIZE={2} "); + } + + if (databaseEdition != DatabaseEdition.None) + { + arguments.Add(" EDITION='{3}' "); + } + + if (!string.IsNullOrEmpty(serviceObjectiveName)) + { + arguments.Add(" SERVICE_OBJECTIVE='{4}' "); + } + + if (arguments.Count > 0) + { + commandText += "(" + string.Join(", ", arguments.ToArray()) + ")"; + } + + string maxSizeVal = string.Empty; + if (databaseMaxSizeGb != null) + { + maxSizeVal = databaseMaxSizeGb.Value.ToString() + "GB"; + } + else if (databaseMaxSizeBytes != null) + { + if (databaseMaxSizeBytes > (500 * 1024 * 1024)) + { + maxSizeVal = (databaseMaxSizeBytes / (1024 * 1024 * 1024)).ToString() + "GB"; + } + else + { + maxSizeVal = (databaseMaxSizeBytes / (1024 * 1024)).ToString() + "MB"; + } + } + + commandText = string.Format( + commandText, + SqlEscape(databaseName), + SqlEscape(databaseCollation), + SqlEscape(maxSizeVal), + SqlEscape(databaseEdition.ToString()), + SqlEscape(serviceObjectiveName)); + + builder["Database"] = null; + using (var connection = CreateConnection()) + { + using (DbCommand command = connection.CreateCommand()) + { + command.CommandTimeout = connectionTimeout; + command.CommandText = commandText; + connection.Open(); + + command.ExecuteNonQuery(); + } + } + + return GetDatabase(databaseName); + } + + /// + /// Updates the property on the database with the name . + /// + /// The sql connection string information + /// The database to update. + /// The new database name, or null to not update. + /// The max size for the database in GB. + /// The max size for the database in bytes. + /// The new database edition, or null to not update. + /// The new service objective, or null to not update. + /// The updated database object. + public Database UpdateDatabase( + string databaseName, + string newDatabaseName, + int? databaseMaxSizeGb, + long? databaseMaxSizeBytes, + DatabaseEdition? databaseEdition, + ServiceObjective serviceObjective) + { + return UpdateDatabase( + databaseName, + newDatabaseName, + databaseMaxSizeGb, + databaseMaxSizeBytes, + databaseEdition, + serviceObjective == null ? null : serviceObjective.Name); + } + + /// + /// Updates the property on the database with the name . + /// + /// The sql connection string information + /// The database to update. + /// The new database name, or null to not update. + /// The max size for the database in GB. + /// The max size for the database in bytes. + /// The new database edition, or null to not update. + /// The new service objective name, or null to not update. + /// The updated database object. + public Database UpdateDatabase( + string databaseName, + string newDatabaseName, + int? databaseMaxSizeGb, + long? databaseMaxSizeBytes, + DatabaseEdition? databaseEdition, + string serviceObjectiveName) + { + Database result = null; + + if (!string.IsNullOrEmpty(newDatabaseName)) + { + result = AlterDatabaseName(databaseName, newDatabaseName); + databaseName = newDatabaseName; + } + + if (databaseMaxSizeBytes.HasValue || + databaseMaxSizeGb.HasValue || + databaseEdition.HasValue || + !string.IsNullOrEmpty(serviceObjectiveName)) + { + string sizeVal = null; + if (databaseMaxSizeGb.HasValue) + { + sizeVal = databaseMaxSizeGb.Value.ToString() + "GB"; + } + else if (databaseMaxSizeBytes.HasValue) + { + if (databaseMaxSizeBytes.Value > 500 * 1024 * 1024) + { + sizeVal = (databaseMaxSizeBytes.Value / (1024 * 1024 * 1024)).ToString() + "GB"; + } + else + { + sizeVal = (databaseMaxSizeBytes.Value / (1024 * 1024)).ToString() + "MB"; + } + } + + result = AlterDatabaseProperties(databaseName, sizeVal, databaseEdition, serviceObjectiveName); + } + + result.Context = this; + + return result; + } + + /// + /// Removes the database with the name . + /// + /// The sql connection string information + /// The database to remove. + public void RemoveDatabase(string databaseName) + { + string commandText = "DROP DATABASE [{0}]"; + + commandText = string.Format(commandText, SqlEscape(databaseName)); + + builder["Database"] = null; + using (var connection = CreateConnection()) + { + using (DbCommand command = connection.CreateCommand()) + { + connection.Open(); + command.CommandText = commandText; + + command.ExecuteNonQuery(); + } + } + } + + /// + /// Gets a list of all the available service objectives + /// + /// An array of service objectives + public ServiceObjective[] GetServiceObjectives() + { + ServiceObjective[] list = new[] + { + new ServiceObjective() + { + Context = this, + Enabled = true, + Id = new Guid("dd6d99bb-f193-4ec1-86f2-43d3bccbc49c "), + IsDefault = true, + IsSystem = false, + Name = "Basic" + }, + new ServiceObjective() + { + Context = this, + Enabled = true, + Id = new Guid("f1173c43-91bd-4aaa-973c-54e79e15235b "), + IsDefault = false, + IsSystem = false, + Name = "S0" + }, + new ServiceObjective() + { + Context = this, + Enabled = true, + Id = new Guid("1b1ebd4d-d903-4baa-97f9-4ea675f5e928 "), + IsDefault = false, + IsSystem = false, + Name = "S1" + }, + new ServiceObjective() + { + Context = this, + Enabled = true, + Id = new Guid("455330e1-00cd-488b-b5fa-177c226f28b7"), + IsDefault = false, + IsSystem = false, + Name = "S2" + }, + new ServiceObjective() + { + Context = this, + Enabled = true, + Id = new Guid("7203483a-c4fb-4304-9e9f-17c71c904f5d "), + IsDefault = false, + IsSystem = false, + Name = "P1" + }, + new ServiceObjective() + { + Context = this, + Enabled = true, + Id = new Guid("a7d1b92d-c987-4375-b54d-2b1d0e0f5bb0 "), + IsDefault = false, + IsSystem = false, + Name = "P2" + }, + new ServiceObjective() + { + Context = this, + Enabled = true, + Id = new Guid("a7c4c615-cfb1-464b-b252-925be0a19446"), + IsDefault = false, + IsSystem = false, + Name = "P3" + }, + }; + + return list; + } + + /// + /// Gets a service objective by name + /// + /// The name of the service objective to retrieve + /// A service objective + public ServiceObjective GetServiceObjective(string serviceObjectiveName) + { + return GetServiceObjectives().Where((slo) => slo.Name == serviceObjectiveName).FirstOrDefault(); + } + + public ServiceObjective GetServiceObjective(ServiceObjective serviceObjective) + { + return GetServiceObjective(serviceObjective.Name); + } + + public ServerQuota GetQuota(string quotaName) + { + return null; + } + + public ServerQuota[] GetQuotas() + { + return null; + } + + public DatabaseOperation GetDatabaseOperation(Guid OperationGuid) + { + throw new NotImplementedException(); + } + + public DatabaseOperation[] GetDatabaseOperations(string databaseName) + { + throw new NotImplementedException(); + } + + public DatabaseOperation[] GetDatabasesOperations() + { + throw new NotImplementedException(); + } + + public Model.DatabaseCopy[] GetDatabaseCopy(string databaseName, string partnerServer, string partnerDatabaseName) + { + throw new NotImplementedException(); + } + + public Model.DatabaseCopy GetDatabaseCopy(Model.DatabaseCopy databaseCopy) + { + throw new NotImplementedException(); + } + + public Model.DatabaseCopy StartDatabaseCopy(string databaseName, string partnerServer, string partnerDatabaseName, bool continuousCopy, bool isOfflineSecondary) + { + throw new NotImplementedException(); + } + + public void StopDatabaseCopy(Model.DatabaseCopy databaseCopy, bool forcedTermination) + { + throw new NotImplementedException(); + } + + public RestorableDroppedDatabase[] GetRestorableDroppedDatabases() + { + throw new NotImplementedException(); + } + + public RestorableDroppedDatabase GetRestorableDroppedDatabase(string databaseName, DateTime deletionDate) + { + throw new NotImplementedException(); + } + + public RestoreDatabaseOperation RestoreDatabase(string sourceDatabaseName, DateTime? sourceDatabaseDeletionDate, string targetServerName, string targetDatabaseName, DateTime? pointInTime) + { + throw new NotImplementedException(); + } + + #region Helpers + + /// + /// Checks if a value is null or DBNull and returns default(T). Otherwise returns the value + /// casted to the desired type. + /// + /// The type to cast to + /// The object to cast + /// The result + private T ConvertFromDbValue(object obj) + { + if (obj == null || Convert.IsDBNull(obj)) + { + return default(T); + } + else + { + return (T)obj; + } + } + + /// + /// Given a SqlDataReader extracts the necessary information to populate a database object + /// + /// The reader created from the GetDatabaseQuery + /// The new database + private Database PopulateDatabaseFromReader(DbDataReader reader) + { + Database db = new Database(); + db.Name = ConvertFromDbValue(reader["name"]); + db.CollationName = ConvertFromDbValue(reader["collation_name"]); + db.CreationDate = ConvertFromDbValue(reader["create_date"]); + db.Id = ConvertFromDbValue(reader["database_id"]); + db.IsFederationMember = ConvertFromDbValue(reader["is_federation_member"]); + db.IsFederationRoot = false; + db.IsQueryStoreOn = ConvertFromDbValue(reader["is_query_store_on"]); + db.IsQueryStoreReadOnly = false; + db.IsReadOnly = ConvertFromDbValue(reader["is_read_only"]); + db.IsRecursiveTriggersOn = ConvertFromDbValue(reader["is_recursive_triggers_on"]); + db.IsSuspended = false; + db.IsSystemObject = ConvertFromDbValue(reader["is_system_object"]); + db.QueryStoreClearAll = null; + db.QueryStoreFlushPeriodSeconds = null; + db.QueryStoreIntervalLengthMinutes = null; + db.QueryStoreMaxSizeMB = null; + db.QueryStoreStaleQueryThresholdDays = null; + db.RecoveryPeriodStartDate = null; + db.Status = ConvertFromDbValue(reader["status"]); + return db; + } + + /// + /// Gets some additional database properties (edition, maxsizebytes, ...) from the database + /// + /// + private void GetDatabaseProperties(Database db) + { + if (!string.IsNullOrEmpty(db.Name)) + { + string commandText = + "SELECT " + + "DatabasePropertyEx(@name, 'Edition') as edition, " + + "DatabasePropertyEx(@name, 'MaxSizeInBytes') as maxSizeBytes"; + //TODO: need to get current SLO here. + + builder["Database"] = db.Name; + using (var connection = CreateConnection()) + { + using (DbCommand command = connection.CreateCommand()) + { + command.CommandTimeout = connectionTimeout; + command.CommandText = commandText; + DbParameter param = command.CreateParameter(); + param.ParameterName = "@name"; + param.Value = db.Name; + command.Parameters.Add(param); + + connection.Open(); + using (DbDataReader reader = command.ExecuteReader()) + { + while (reader.Read()) + { + db.MaxSizeBytes = (long)reader["maxSizeBytes"]; + db.Edition = (string)reader["edition"]; + } + } + } + } + } + if (db.MaxSizeBytes.HasValue) + { + db.MaxSizeGB = (int)(db.MaxSizeBytes / (1024 * 1024 * 1024)); + } + + builder["Database"] = null; + } + + /// + /// Escape all the occurances of ']' in the input string. + /// + /// The input string to sanitize + /// The escaped string + string SqlEscape(string input) + { + if (string.IsNullOrEmpty(input)) + return input; + else + return input.Replace("]", "]]").Replace("'", "''"); + } + + /// + /// Alter the database properties + /// + /// The name of the database to alter + /// The new size of the database (format: ##{GB|MB}) + /// The new edition for the database + /// The new service objective name + /// The altered database + private Database AlterDatabaseProperties(string databaseName, string sizeVal, DatabaseEdition? databaseEdition, string serviceObjectiveName) + { + string commandText = + "ALTER DATABASE [{0}] MODIFY "; + + List arguments = new List(); + + if (!string.IsNullOrEmpty(sizeVal)) + { + arguments.Add(" MAXSIZE={1} "); + } + + string edition = string.Empty; + if (databaseEdition.HasValue && + databaseEdition.Value != DatabaseEdition.None) + { + arguments.Add(" EDITION='{2}' "); + edition = databaseEdition.Value.ToString(); + } + + if (!string.IsNullOrEmpty(serviceObjectiveName)) + { + arguments.Add(" SERVICE_OBJECTIVE='{3}' "); + } + + if (arguments.Count > 0) + { + commandText += " (" + string.Join(", ", arguments.ToArray()) + ")"; + } + + commandText = string.Format( + commandText, + SqlEscape(databaseName), + SqlEscape(sizeVal), + SqlEscape(edition), + SqlEscape(serviceObjectiveName)); + + builder["Database"] = null; + using (var connection = CreateConnection()) + { + using (DbCommand command = connection.CreateCommand()) + { + command.CommandTimeout = connectionTimeout; + command.CommandText = commandText; + connection.Open(); + + command.ExecuteNonQuery(); + } + } + + return GetDatabase(databaseName); + } + + /// + /// Used to alter the name of a databse. + /// + /// Current database name + /// Desired new name + /// The resultant database object + private Database AlterDatabaseName(string databaseName, string newDatabaseName) + { + string commandText = + "ALTER DATABASE [{0}] MODIFY NAME = [{1}]"; + + commandText = string.Format( + commandText, + SqlEscape(databaseName), + SqlEscape(newDatabaseName)); + + builder["Database"] = null; + using (var connection = CreateConnection()) + { + using (DbCommand command = connection.CreateCommand()) + { + command.CommandTimeout = connectionTimeout; + command.CommandText = commandText; + connection.Open(); + + command.ExecuteNonQuery(); + } + } + + return GetDatabase(newDatabaseName); + } + + /// + /// Used to load extra properties + /// + /// + public void LoadExtraProperties(object obj) + { + } + + #endregion + } +} From fc5d0e328bffd449222da0feb2ce4826c8226d0a Mon Sep 17 00:00:00 2001 From: adamkr Date: Fri, 5 Dec 2014 19:08:46 -0800 Subject: [PATCH 2/6] Addressing code review comments --- .../Database/Cmdlet/SqlAuthv12MockTests.cs | 18 +++- .../TSql/CustomAttributeProviderExtensions.cs | 87 ------------------ .../Commands.SqlDatabase.csproj | 1 + .../NewAzureSqlDatabaseServerContext.cs | 90 ++++-------------- .../Services/Server/SqlAuthContextFactory.cs | 91 +++++++++++++++++++ .../Services/Server/TSqlConnectionContext.cs | 41 ++++----- 6 files changed, 145 insertions(+), 183 deletions(-) delete mode 100644 src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/CustomAttributeProviderExtensions.cs create mode 100644 src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/SqlAuthContextFactory.cs diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs index 97a6ad54c7fa..456e0fa1268d 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs @@ -1,4 +1,18 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// --- + +using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.WindowsAzure.Commands.SqlDatabase.Services.Server; using Microsoft.WindowsAzure.Commands.SqlDatabase.Test.UnitTests.TSql; using System; @@ -35,11 +49,9 @@ public void Cleanup() [TestMethod] public void NewAzureSqlDatabaseWithSqlAuthv12() { - using (System.Management.Automation.PowerShell powershell = System.Management.Automation.PowerShell.Create()) { - // Create a context NewAzureSqlDatabaseServerContextTests.CreateServerContextSqlAuthV2( powershell, diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/CustomAttributeProviderExtensions.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/CustomAttributeProviderExtensions.cs deleted file mode 100644 index 1e65d00cb5a6..000000000000 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/CustomAttributeProviderExtensions.cs +++ /dev/null @@ -1,87 +0,0 @@ -using System; -using System.Reflection; - -namespace Microsoft.SqlServer.Management.Relational.Domain.UnitTest -{ - /// - /// Class that extends ICustomAttributeProvider - /// to allow for type-safe access to custom attributes. - /// - internal static class CustomAttributeProviderExtensions - { - /// - /// Retrieves a list of custom attributes of given type. - /// - /// Type of attribute to retrieve. - /// Object from which to request the custom attributes. - /// Array of attributes - public static T[] GetCustomAttributes(this ICustomAttributeProvider provider) where T : Attribute - { - return GetCustomAttributes(provider, false); - } - - /// - /// Retrieves a list of custom attributes of given type. - /// - /// Type of attribute to retrieve. - /// Object from which to request the custom attributes. - /// Specifies wheather the attributes can be inherited from parent object - /// Array of attributes - public static T[] GetCustomAttributes(this ICustomAttributeProvider provider, bool inherit) where T : Attribute - { - if (provider == null) - { - throw new ArgumentNullException("provider"); - } - - T[] attributes = provider.GetCustomAttributes(typeof(T), inherit) as T[]; - if (attributes == null) - { - return new T[0]; - } - - return attributes; - } - - /// - /// Retrieves a single custom attribute of given type. - /// - /// Type of attribute to retrieve. - /// Object from which to request the custom attributes. - /// An attribute obtained or null. - public static T GetCustomAttribute(this ICustomAttributeProvider provider) where T : Attribute - { - return GetCustomAttribute(provider, false); - } - - /// - /// Retrieves a single custom attribute of given type. - /// - /// Type of attribute to retrieve. - /// Object from which to request the custom attributes. - /// Specifies wheather the attributes can be inherited from parent object - /// An attribute obtained or null. - public static T GetCustomAttribute(this ICustomAttributeProvider provider, bool inherit) where T : Attribute - { - T[] attributes = GetCustomAttributes(provider, inherit); - - if (attributes.Length > 1) - { - throw new InvalidOperationException( - string.Format( - "Domain element is expected to contain 1 attribute(s) of type [{1}], but it contains {0} attribute(s).", - attributes.Length, - typeof(T).Name)); - } - - if (attributes.Length == 1) - { - return attributes[0]; - } - else - { - return null; - } - } - } -} \ No newline at end of file diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase/Commands.SqlDatabase.csproj b/src/ServiceManagement/Sql/Commands.SqlDatabase/Commands.SqlDatabase.csproj index e33ac6d52db3..9b3a8453e27e 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase/Commands.SqlDatabase.csproj +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase/Commands.SqlDatabase.csproj @@ -190,6 +190,7 @@ + diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabaseServerContext.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabaseServerContext.cs index f7b5a0ec8abf..bcf2cffbe686 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabaseServerContext.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabaseServerContext.cs @@ -89,10 +89,10 @@ public class NewAzureSqlDatabaseServerContext : AzurePSCmdlet /// /// Gets or sets the management site data connection fully qualified server name. /// - [Parameter(Mandatory = true, Position = 0, + [Parameter(Mandatory = true, Position = 0, ParameterSetName = FullyQualifiedServerNameWithSqlAuthParamSet, HelpMessage = "The fully qualified server name")] - [Parameter(Mandatory = true, Position = 0, + [Parameter(Mandatory = true, Position = 0, ParameterSetName = FullyQualifiedServerNameWithCertAuthParamSet, HelpMessage = "The fully qualified server name")] [ValidateNotNull] @@ -109,13 +109,13 @@ public class NewAzureSqlDatabaseServerContext : AzurePSCmdlet /// /// Gets or sets the server credentials /// - [Parameter(Mandatory = true, Position = 1, + [Parameter(Mandatory = true, Position = 1, ParameterSetName = ServerNameWithSqlAuthParamSet, HelpMessage = "The credentials for the server")] - [Parameter(Mandatory = true, Position = 1, + [Parameter(Mandatory = true, Position = 1, ParameterSetName = FullyQualifiedServerNameWithSqlAuthParamSet, HelpMessage = "The credentials for the server")] - [Parameter(Mandatory = true, Position = 1, + [Parameter(Mandatory = true, Position = 1, ParameterSetName = ManageUrlWithSqlAuthParamSet, HelpMessage = "The credentials for the server")] [ValidateNotNull] @@ -124,10 +124,10 @@ public class NewAzureSqlDatabaseServerContext : AzurePSCmdlet /// /// Gets or sets whether or not the current subscription should be used for authentication /// - [Parameter(Mandatory = true, Position = 1, + [Parameter(Mandatory = true, Position = 1, ParameterSetName = ServerNameWithCertAuthParamSet, HelpMessage = "Use certificate authentication")] - [Parameter(Mandatory = true, Position = 1, + [Parameter(Mandatory = true, Position = 1, ParameterSetName = FullyQualifiedServerNameWithCertAuthParamSet, HelpMessage = "Use certificate authentication")] public SwitchParameter UseSubscription { get; set; } @@ -140,13 +140,6 @@ public class NewAzureSqlDatabaseServerContext : AzurePSCmdlet HelpMessage = "The subscription to use, or uses the current subscription if not specified")] public string SubscriptionName { get; set; } - /// - /// Switch to indiciate the the server is an ESA server - /// - [Parameter(Mandatory = false, - HelpMessage = "Indicates the server version being targeted. Valid values [2.0, 12.0]. Default = 2.0")] - public float Version { get; set; } - #endregion #region Current Subscription Management @@ -188,63 +181,20 @@ internal IServerDataServiceContext GetServerDataServiceBySqlAuth( IServerDataServiceContext context = null; Guid sessionActivityId = Guid.NewGuid(); - if (this.MyInvocation.BoundParameters.ContainsKey("Version")) + try { - if (this.Version == 12.0f) - { - try - { - context = new TSqlConnectionContext( - sessionActivityId, - manageUrl.Host, - credentials.UserName, - credentials.Password); - } - catch (Exception ex) - { - SqlDatabaseExceptionHandler.WriteErrorDetails( - this, - sessionActivityId.ToString(), - ex); - - // The context is not in an valid state because of the error, set the context - // back to null. - context = null; - } - } - else - { - try - { - context = ServerDataServiceSqlAuth.Create( - managementServiceUri, - sessionActivityId, - credentials, - serverName); - - // Retrieve $metadata to verify model version compatibility - XDocument metadata = ((ServerDataServiceSqlAuth)context).RetrieveMetadata(); - XDocument filteredMetadata = DataConnectionUtility.FilterMetadataDocument(metadata); - string metadataHash = DataConnectionUtility.GetDocumentHash(filteredMetadata); - if (!((ServerDataServiceSqlAuth)context).metadataHashes.Any(knownHash => metadataHash == knownHash)) - { - this.WriteWarning(Resources.WarningModelOutOfDate); - } - - ((ServerDataServiceSqlAuth)context).MergeOption = MergeOption.PreserveChanges; - } - catch (Exception ex) - { - SqlDatabaseExceptionHandler.WriteErrorDetails( - this, - sessionActivityId.ToString(), - ex); - - // The context is not in an valid state because of the error, set the context - // back to null. - context = null; - } - } + context = SqlAuthContextFactory.GetContext(this, serverName, manageUrl, credentials, sessionActivityId, managementServiceUri); + } + catch (Exception ex) + { + SqlDatabaseExceptionHandler.WriteErrorDetails( + this, + sessionActivityId.ToString(), + ex); + + // The context is not in an valid state because of the error, set the context + // back to null. + context = null; } return context; diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/SqlAuthContextFactory.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/SqlAuthContextFactory.cs new file mode 100644 index 000000000000..0cc49e16ee9d --- /dev/null +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/SqlAuthContextFactory.cs @@ -0,0 +1,91 @@ +using Microsoft.WindowsAzure.Commands.SqlDatabase.Properties; +using Microsoft.WindowsAzure.Commands.SqlDatabase.Services.Common; +using System; +using System.Collections.Generic; +using System.Data.Services.Client; +using System.Data.SqlClient; +using System.Linq; +using System.Management.Automation; +using System.Text; +using System.Threading.Tasks; +using System.Xml.Linq; + +namespace Microsoft.WindowsAzure.Commands.SqlDatabase.Services.Server +{ + public class SqlAuthContextFactory + { + public static IServerDataServiceContext GetContext( + PSCmdlet cmdlet, + string serverName, + Uri manageUrl, + SqlAuthenticationCredentials credentials, + Guid sessionActivityId, + Uri managementServiceUri) + { + Version version = GetVersion(manageUrl, credentials); + + IServerDataServiceContext context = null; + + if (version.Major >= 12) + { + context = new TSqlConnectionContext( + sessionActivityId, + manageUrl.Host, + credentials.UserName, + credentials.Password); + } + else + { + context = ServerDataServiceSqlAuth.Create( + managementServiceUri, + sessionActivityId, + credentials, + serverName); + + // Retrieve $metadata to verify model version compatibility + XDocument metadata = ((ServerDataServiceSqlAuth)context).RetrieveMetadata(); + XDocument filteredMetadata = DataConnectionUtility.FilterMetadataDocument(metadata); + string metadataHash = DataConnectionUtility.GetDocumentHash(filteredMetadata); + if (!((ServerDataServiceSqlAuth)context).metadataHashes.Any(knownHash => metadataHash == knownHash)) + { + cmdlet.WriteWarning(Resources.WarningModelOutOfDate); + } + + ((ServerDataServiceSqlAuth)context).MergeOption = MergeOption.PreserveChanges; + } + + return context; + } + + /// + /// Queries the server to get the server version + /// + /// The manage url of the server. Eg: https://{serverName}.database.windows.net + /// The login credentials + /// The server version + private static Version GetVersion(Uri manageUrl, SqlAuthenticationCredentials credentials) + { + string serverName = manageUrl.Host.Split('.').First(); + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(); + builder["Server"] = manageUrl.Host; + builder.UserID = credentials.UserName + "@" + serverName; + builder.Password = credentials.Password; + builder["Database"] = null; + builder["Encrypt"] = false; + builder.ConnectTimeout = 60; + + string commandText = "select serverproperty('ProductVersion')"; + + using(SqlConnection conn = new SqlConnection(builder.ConnectionString)) + { + using (SqlCommand command = new SqlCommand(commandText, conn)) + { + conn.Open(); + + string val = (string)command.ExecuteScalar(); + return new Version(val); + } + } + } + } +} diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs index b7cb652ad20f..6c84fc99ce98 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs @@ -27,7 +27,12 @@ public class TSqlConnectionContext : IServerDataServiceContext /// /// Timeout duration for commands /// - private static int connectionTimeout = 60; + private static int commandTimeout = 300; + + /// + /// Timeout duration for connections + /// + private static int connectionTimeout = 30; /// /// Set this to override the SQL Connection with a mock version @@ -248,7 +253,6 @@ public Database GetDatabase(string databaseName) using (var connection = CreateConnection()) { using (DbCommand command = connection.CreateCommand()) - //using (SqlCommand command = new SqlCommand(getDatabaseQuery, connection)) { command.CommandTimeout = connectionTimeout; command.CommandText = getDatabaseQuery; @@ -259,7 +263,6 @@ public Database GetDatabase(string databaseName) connection.Open(); - //using (SqlDataReader reader = command.ExecuteReader()) using (DbDataReader reader = command.ExecuteReader()) { if (reader.HasRows) @@ -277,15 +280,6 @@ public Database GetDatabase(string databaseName) } } - //db.ServiceObjectiveName; - //db.serviceobjectiveid - //db.SLOAssignment* - - //db.Server; - //db.ServiceObjective; - //db.DatabaseCopies; - //db.DatabaseMetrics; - GetDatabaseProperties(db); db.ServiceObjective = GetServiceObjective(db.ServiceObjectiveName); db.Context = this; @@ -399,7 +393,7 @@ public Database CreateNewDatabase( { using (DbCommand command = connection.CreateCommand()) { - command.CommandTimeout = connectionTimeout; + command.CommandTimeout = commandTimeout; command.CommandText = commandText; connection.Open(); @@ -513,6 +507,7 @@ public void RemoveDatabase(string databaseName) { connection.Open(); command.CommandText = commandText; + command.CommandTimeout = commandTimeout; command.ExecuteNonQuery(); } @@ -637,37 +632,37 @@ public DatabaseOperation[] GetDatabasesOperations() public Model.DatabaseCopy[] GetDatabaseCopy(string databaseName, string partnerServer, string partnerDatabaseName) { - throw new NotImplementedException(); + throw new NotSupportedException(); } public Model.DatabaseCopy GetDatabaseCopy(Model.DatabaseCopy databaseCopy) { - throw new NotImplementedException(); + throw new NotSupportedException(); } public Model.DatabaseCopy StartDatabaseCopy(string databaseName, string partnerServer, string partnerDatabaseName, bool continuousCopy, bool isOfflineSecondary) { - throw new NotImplementedException(); + throw new NotSupportedException(); } public void StopDatabaseCopy(Model.DatabaseCopy databaseCopy, bool forcedTermination) { - throw new NotImplementedException(); + throw new NotSupportedException(); } public RestorableDroppedDatabase[] GetRestorableDroppedDatabases() { - throw new NotImplementedException(); + throw new NotSupportedException(); } public RestorableDroppedDatabase GetRestorableDroppedDatabase(string databaseName, DateTime deletionDate) { - throw new NotImplementedException(); + throw new NotSupportedException(); } public RestoreDatabaseOperation RestoreDatabase(string sourceDatabaseName, DateTime? sourceDatabaseDeletionDate, string targetServerName, string targetDatabaseName, DateTime? pointInTime) { - throw new NotImplementedException(); + throw new NotSupportedException(); } #region Helpers @@ -740,7 +735,7 @@ private void GetDatabaseProperties(Database db) { using (DbCommand command = connection.CreateCommand()) { - command.CommandTimeout = connectionTimeout; + command.CommandTimeout = commandTimeout; command.CommandText = commandText; DbParameter param = command.CreateParameter(); param.ParameterName = "@name"; @@ -830,7 +825,7 @@ private Database AlterDatabaseProperties(string databaseName, string sizeVal, Da { using (DbCommand command = connection.CreateCommand()) { - command.CommandTimeout = connectionTimeout; + command.CommandTimeout = commandTimeout; command.CommandText = commandText; connection.Open(); @@ -862,7 +857,7 @@ private Database AlterDatabaseName(string databaseName, string newDatabaseName) { using (DbCommand command = connection.CreateCommand()) { - command.CommandTimeout = connectionTimeout; + command.CommandTimeout = commandTimeout; command.CommandText = commandText; connection.Open(); From 5f450fe8f479239b70ae93800e9d77fe9452ce82 Mon Sep 17 00:00:00 2001 From: adamkr Date: Fri, 5 Dec 2014 19:12:24 -0800 Subject: [PATCH 3/6] Adding a test playback file. --- .../TSqlMockSessions/SqlAuthv12MockTests.xml | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 src/ServiceManagement/Sql/Commands.SqlDatabase.Test/Resources/TSqlMockSessions/SqlAuthv12MockTests.xml diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/Resources/TSqlMockSessions/SqlAuthv12MockTests.xml b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/Resources/TSqlMockSessions/SqlAuthv12MockTests.xml new file mode 100644 index 000000000000..77d6d831521f --- /dev/null +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/Resources/TSqlMockSessions/SqlAuthv12MockTests.xml @@ -0,0 +1,111 @@ + + + + SqlAuthv12MockTests + master + + SELECT + [db].[name], + [db].[database_id], + [db].[create_date], + [db].[collation_name], + [db].[is_read_only], + [db].[is_query_store_on], + [db].[is_recursive_triggers_on], + [db].[is_federation_member], + CONVERT (bit, CASE WHEN [db].[name] in ('master') THEN 1 ELSE [db].[is_distributor] END) AS [is_system_object], + CONVERT (int, + CASE + WHEN [db].[is_in_standby] = 1 THEN 0x0040 + ELSE 0 + END | + CASE + WHEN [db].[is_cleanly_shutdown] = 1 THEN 0x0080 + ELSE 0 + END | + CASE [db].[state] + WHEN 0 THEN 0x0001 -- NORMAL + WHEN 1 THEN 0x0002 -- RESTORING + WHEN 2 THEN 0x0008 -- RECOVERING + WHEN 3 THEN 0x0004 -- RECOVERY_PENDING + WHEN 4 THEN 0x0010 -- SUSPECT + WHEN 5 THEN 0x0100 -- EMERGENCY + WHEN 6 THEN 0x0020 -- OFFLINE + WHEN 7 THEN 0x0400 -- COPYING + WHEN 9 THEN 0x0800 -- CREATING + WHEN 10 THEN 0x1000 -- OFFLINE_SECONDARY + ELSE 0x0010 -- SUSPECT + END) AS [status] + FROM [sys].[databases] AS [db] + WHERE ([db].[name] = testdb4 OR testdb4 IS NULL) + + + + + + + + + + + + + + + + + + + + + + + + + + + + testdb4 + 8 + 2014-12-05T15:57:43.137-08:00 + SQL_Latin1_General_CP1_CI_AS + false + false + false + false + false + 1 +
+
+
+
+ + SqlAuthv12MockTests + testdb4 + SELECT DatabasePropertyEx(testdb4, 'Edition') as edition, DatabasePropertyEx(testdb4, 'MaxSizeInBytes') as maxSizeBytes + + + + + + + + + + + + + + + + + + + + Standard + 268435456000 +
+
+
+
+
From 6f3380c58242eb5dfa72b3b1b9205e269b5c0749 Mon Sep 17 00:00:00 2001 From: adamkr Date: Sat, 6 Dec 2014 14:48:35 -0800 Subject: [PATCH 4/6] Adding additional tests. --- .../TSqlMockSessions/SqlAuthv12MockTests.xml | 1113 ++++++++++++++++- .../NewAzureSqlDatabaseServerContextTests.cs | 3 +- .../Database/Cmdlet/SqlAuthv12MockTests.cs | 174 ++- .../UnitTests/TSql/MockSettings.cs | 33 +- .../UnitTests/TSql/MockSqlCommand.cs | 68 +- .../UnitTests/TSql/MockSqlConnection.cs | 18 - .../Services/Server/TSqlConnectionContext.cs | 3 +- 7 files changed, 1321 insertions(+), 91 deletions(-) diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/Resources/TSqlMockSessions/SqlAuthv12MockTests.xml b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/Resources/TSqlMockSessions/SqlAuthv12MockTests.xml index 77d6d831521f..02e4c3036230 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/Resources/TSqlMockSessions/SqlAuthv12MockTests.xml +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/Resources/TSqlMockSessions/SqlAuthv12MockTests.xml @@ -1,7 +1,7 @@ - SqlAuthv12MockTests + SqlAuthv12MockTests.GetAzureSqlDatabaseWithSqlAuthv12 master SELECT @@ -37,7 +37,115 @@ ELSE 0x0010 -- SUSPECT END) AS [status] FROM [sys].[databases] AS [db] - WHERE ([db].[name] = testdb4 OR testdb4 IS NULL) + WHERE ([db].[name] = testdb1 OR testdb1 IS NULL) + + + + + + + + + + + + + + + + + + + + + + + + + + + + testdb1 + 5 + 2014-12-05T15:52:35.953-08:00 + SQL_Latin1_General_CP1_CI_AS + false + false + false + false + false + 1 +
+
+
+
+ + SqlAuthv12MockTests.GetAzureSqlDatabaseWithSqlAuthv12 + testdb1 + SELECT DatabasePropertyEx(testdb1, 'Edition') as edition, DatabasePropertyEx(testdb1, 'MaxSizeInBytes') as maxSizeBytes + + + + + + + + + + + + + + + + + + + + Standard + 268435456000 +
+
+
+
+ + SqlAuthv12MockTests.GetAzureSqlDatabaseWithSqlAuthv12 + master + + SELECT + [db].[name], + [db].[database_id], + [db].[create_date], + [db].[collation_name], + [db].[is_read_only], + [db].[is_query_store_on], + [db].[is_recursive_triggers_on], + [db].[is_federation_member], + CONVERT (bit, CASE WHEN [db].[name] in ('master') THEN 1 ELSE [db].[is_distributor] END) AS [is_system_object], + CONVERT (int, + CASE + WHEN [db].[is_in_standby] = 1 THEN 0x0040 + ELSE 0 + END | + CASE + WHEN [db].[is_cleanly_shutdown] = 1 THEN 0x0080 + ELSE 0 + END | + CASE [db].[state] + WHEN 0 THEN 0x0001 -- NORMAL + WHEN 1 THEN 0x0002 -- RESTORING + WHEN 2 THEN 0x0008 -- RECOVERING + WHEN 3 THEN 0x0004 -- RECOVERY_PENDING + WHEN 4 THEN 0x0010 -- SUSPECT + WHEN 5 THEN 0x0100 -- EMERGENCY + WHEN 6 THEN 0x0020 -- OFFLINE + WHEN 7 THEN 0x0400 -- COPYING + WHEN 9 THEN 0x0800 -- CREATING + WHEN 10 THEN 0x1000 -- OFFLINE_SECONDARY + ELSE 0x0010 -- SUSPECT + END) AS [status] + FROM [sys].[databases] AS [db] + WHERE ([db].[name] = OR IS NULL) @@ -64,6 +172,54 @@ + + master + 1 + 2014-12-05T12:35:00.73-08:00 + SQL_Latin1_General_CP1_CI_AS + false + false + false + false + true + 1 +
+ + testdb1 + 5 + 2014-12-05T15:52:35.953-08:00 + SQL_Latin1_General_CP1_CI_AS + false + false + false + false + false + 1 +
+ + testdb2 + 6 + 2014-12-05T15:54:56.603-08:00 + Japanese_CI_AS + false + false + false + false + false + 1 +
+ + testdb3 + 7 + 2014-12-05T15:55:29.173-08:00 + SQL_Latin1_General_CP1_CI_AS + false + false + false + false + false + 1 +
testdb48 @@ -80,7 +236,94 @@ - SqlAuthv12MockTests + SqlAuthv12MockTests.GetAzureSqlDatabaseWithSqlAuthv12 + master + SELECT DatabasePropertyEx(master, 'Edition') as edition, DatabasePropertyEx(master, 'MaxSizeInBytes') as maxSizeBytes + + + + + + + + + + + + + + + + + + +
+ System + 5368709120 +
+
+
+
+ + SqlAuthv12MockTests.GetAzureSqlDatabaseWithSqlAuthv12 + testdb2 + SELECT DatabasePropertyEx(testdb2, 'Edition') as edition, DatabasePropertyEx(testdb2, 'MaxSizeInBytes') as maxSizeBytes + + + + + + + + + + + + + + + + + + + + Basic + 2147483648 +
+
+
+
+ + SqlAuthv12MockTests.GetAzureSqlDatabaseWithSqlAuthv12 + testdb3 + SELECT DatabasePropertyEx(testdb3, 'Edition') as edition, DatabasePropertyEx(testdb3, 'MaxSizeInBytes') as maxSizeBytes + + + + + + + + + + + + + + + + + + + + Standard + 107374182400 +
+
+
+
+ + SqlAuthv12MockTests.GetAzureSqlDatabaseWithSqlAuthv12 testdb4 SELECT DatabasePropertyEx(testdb4, 'Edition') as edition, DatabasePropertyEx(testdb4, 'MaxSizeInBytes') as maxSizeBytes @@ -108,4 +351,868 @@ + + SqlAuthv12MockTests.NewAzureSqlDatabaseWithSqlAuthv12 + master + + SELECT + [db].[name], + [db].[database_id], + [db].[create_date], + [db].[collation_name], + [db].[is_read_only], + [db].[is_query_store_on], + [db].[is_recursive_triggers_on], + [db].[is_federation_member], + CONVERT (bit, CASE WHEN [db].[name] in ('master') THEN 1 ELSE [db].[is_distributor] END) AS [is_system_object], + CONVERT (int, + CASE + WHEN [db].[is_in_standby] = 1 THEN 0x0040 + ELSE 0 + END | + CASE + WHEN [db].[is_cleanly_shutdown] = 1 THEN 0x0080 + ELSE 0 + END | + CASE [db].[state] + WHEN 0 THEN 0x0001 -- NORMAL + WHEN 1 THEN 0x0002 -- RESTORING + WHEN 2 THEN 0x0008 -- RECOVERING + WHEN 3 THEN 0x0004 -- RECOVERY_PENDING + WHEN 4 THEN 0x0010 -- SUSPECT + WHEN 5 THEN 0x0100 -- EMERGENCY + WHEN 6 THEN 0x0020 -- OFFLINE + WHEN 7 THEN 0x0400 -- COPYING + WHEN 9 THEN 0x0800 -- CREATING + WHEN 10 THEN 0x1000 -- OFFLINE_SECONDARY + ELSE 0x0010 -- SUSPECT + END) AS [status] + FROM [sys].[databases] AS [db] + WHERE ([db].[name] = testdb1 OR testdb1 IS NULL) + + + + + + + + + + + + + + + + + + + + + + + + + + + + testdb1 + 5 + 2014-12-06T12:25:52.49-08:00 + SQL_Latin1_General_CP1_CI_AS + false + false + false + false + false + 1 +
+
+
+
+ + SqlAuthv12MockTests.NewAzureSqlDatabaseWithSqlAuthv12 + testdb1 + SELECT DatabasePropertyEx(testdb1, 'Edition') as edition, DatabasePropertyEx(testdb1, 'MaxSizeInBytes') as maxSizeBytes + + + + + + + + + + + + + + + + + + + + Standard + 268435456000 +
+
+
+
+ + SqlAuthv12MockTests.NewAzureSqlDatabaseWithSqlAuthv12 + master + + SELECT + [db].[name], + [db].[database_id], + [db].[create_date], + [db].[collation_name], + [db].[is_read_only], + [db].[is_query_store_on], + [db].[is_recursive_triggers_on], + [db].[is_federation_member], + CONVERT (bit, CASE WHEN [db].[name] in ('master') THEN 1 ELSE [db].[is_distributor] END) AS [is_system_object], + CONVERT (int, + CASE + WHEN [db].[is_in_standby] = 1 THEN 0x0040 + ELSE 0 + END | + CASE + WHEN [db].[is_cleanly_shutdown] = 1 THEN 0x0080 + ELSE 0 + END | + CASE [db].[state] + WHEN 0 THEN 0x0001 -- NORMAL + WHEN 1 THEN 0x0002 -- RESTORING + WHEN 2 THEN 0x0008 -- RECOVERING + WHEN 3 THEN 0x0004 -- RECOVERY_PENDING + WHEN 4 THEN 0x0010 -- SUSPECT + WHEN 5 THEN 0x0100 -- EMERGENCY + WHEN 6 THEN 0x0020 -- OFFLINE + WHEN 7 THEN 0x0400 -- COPYING + WHEN 9 THEN 0x0800 -- CREATING + WHEN 10 THEN 0x1000 -- OFFLINE_SECONDARY + ELSE 0x0010 -- SUSPECT + END) AS [status] + FROM [sys].[databases] AS [db] + WHERE ([db].[name] = testdb2 OR testdb2 IS NULL) + + + + + + + + + + + + + + + + + + + + + + + + + + + + testdb2 + 6 + 2014-12-06T12:27:55.997-08:00 + Japanese_CI_AS + false + false + false + false + false + 1 +
+
+
+
+ + SqlAuthv12MockTests.NewAzureSqlDatabaseWithSqlAuthv12 + testdb2 + SELECT DatabasePropertyEx(testdb2, 'Edition') as edition, DatabasePropertyEx(testdb2, 'MaxSizeInBytes') as maxSizeBytes + + + + + + + + + + + + + + + + + + + + Basic + 2147483648 +
+
+
+
+ + SqlAuthv12MockTests.NewAzureSqlDatabaseWithSqlAuthv12 + master + + SELECT + [db].[name], + [db].[database_id], + [db].[create_date], + [db].[collation_name], + [db].[is_read_only], + [db].[is_query_store_on], + [db].[is_recursive_triggers_on], + [db].[is_federation_member], + CONVERT (bit, CASE WHEN [db].[name] in ('master') THEN 1 ELSE [db].[is_distributor] END) AS [is_system_object], + CONVERT (int, + CASE + WHEN [db].[is_in_standby] = 1 THEN 0x0040 + ELSE 0 + END | + CASE + WHEN [db].[is_cleanly_shutdown] = 1 THEN 0x0080 + ELSE 0 + END | + CASE [db].[state] + WHEN 0 THEN 0x0001 -- NORMAL + WHEN 1 THEN 0x0002 -- RESTORING + WHEN 2 THEN 0x0008 -- RECOVERING + WHEN 3 THEN 0x0004 -- RECOVERY_PENDING + WHEN 4 THEN 0x0010 -- SUSPECT + WHEN 5 THEN 0x0100 -- EMERGENCY + WHEN 6 THEN 0x0020 -- OFFLINE + WHEN 7 THEN 0x0400 -- COPYING + WHEN 9 THEN 0x0800 -- CREATING + WHEN 10 THEN 0x1000 -- OFFLINE_SECONDARY + ELSE 0x0010 -- SUSPECT + END) AS [status] + FROM [sys].[databases] AS [db] + WHERE ([db].[name] = testdb3 OR testdb3 IS NULL) + + + + + + + + + + + + + + + + + + + + + + + + + + + + testdb3 + 7 + 2014-12-06T12:28:28.81-08:00 + SQL_Latin1_General_CP1_CI_AS + false + false + false + false + false + 1 +
+
+
+
+ + SqlAuthv12MockTests.NewAzureSqlDatabaseWithSqlAuthv12 + testdb3 + SELECT DatabasePropertyEx(testdb3, 'Edition') as edition, DatabasePropertyEx(testdb3, 'MaxSizeInBytes') as maxSizeBytes + + + + + + + + + + + + + + + + + + + + Standard + 107374182400 +
+
+
+
+ + SqlAuthv12MockTests.NewAzureSqlDatabaseWithSqlAuthv12 + master + + SELECT + [db].[name], + [db].[database_id], + [db].[create_date], + [db].[collation_name], + [db].[is_read_only], + [db].[is_query_store_on], + [db].[is_recursive_triggers_on], + [db].[is_federation_member], + CONVERT (bit, CASE WHEN [db].[name] in ('master') THEN 1 ELSE [db].[is_distributor] END) AS [is_system_object], + CONVERT (int, + CASE + WHEN [db].[is_in_standby] = 1 THEN 0x0040 + ELSE 0 + END | + CASE + WHEN [db].[is_cleanly_shutdown] = 1 THEN 0x0080 + ELSE 0 + END | + CASE [db].[state] + WHEN 0 THEN 0x0001 -- NORMAL + WHEN 1 THEN 0x0002 -- RESTORING + WHEN 2 THEN 0x0008 -- RECOVERING + WHEN 3 THEN 0x0004 -- RECOVERY_PENDING + WHEN 4 THEN 0x0010 -- SUSPECT + WHEN 5 THEN 0x0100 -- EMERGENCY + WHEN 6 THEN 0x0020 -- OFFLINE + WHEN 7 THEN 0x0400 -- COPYING + WHEN 9 THEN 0x0800 -- CREATING + WHEN 10 THEN 0x1000 -- OFFLINE_SECONDARY + ELSE 0x0010 -- SUSPECT + END) AS [status] + FROM [sys].[databases] AS [db] + WHERE ([db].[name] = testdb4 OR testdb4 IS NULL) + + + + + + + + + + + + + + + + + + + + + + + + + + + + testdb4 + 8 + 2014-12-06T12:31:55.247-08:00 + SQL_Latin1_General_CP1_CI_AS + false + false + false + false + false + 1 +
+
+
+
+ + SqlAuthv12MockTests.NewAzureSqlDatabaseWithSqlAuthv12 + testdb4 + SELECT DatabasePropertyEx(testdb4, 'Edition') as edition, DatabasePropertyEx(testdb4, 'MaxSizeInBytes') as maxSizeBytes + + + + + + + + + + + + + + + + + + + + Standard + 268435456000 +
+
+
+
+ + SqlAuthv12MockTests.SetAzureSqlDatabaseWithSqlAuthv12 + master + + SELECT + [db].[name], + [db].[database_id], + [db].[create_date], + [db].[collation_name], + [db].[is_read_only], + [db].[is_query_store_on], + [db].[is_recursive_triggers_on], + [db].[is_federation_member], + CONVERT (bit, CASE WHEN [db].[name] in ('master') THEN 1 ELSE [db].[is_distributor] END) AS [is_system_object], + CONVERT (int, + CASE + WHEN [db].[is_in_standby] = 1 THEN 0x0040 + ELSE 0 + END | + CASE + WHEN [db].[is_cleanly_shutdown] = 1 THEN 0x0080 + ELSE 0 + END | + CASE [db].[state] + WHEN 0 THEN 0x0001 -- NORMAL + WHEN 1 THEN 0x0002 -- RESTORING + WHEN 2 THEN 0x0008 -- RECOVERING + WHEN 3 THEN 0x0004 -- RECOVERY_PENDING + WHEN 4 THEN 0x0010 -- SUSPECT + WHEN 5 THEN 0x0100 -- EMERGENCY + WHEN 6 THEN 0x0020 -- OFFLINE + WHEN 7 THEN 0x0400 -- COPYING + WHEN 9 THEN 0x0800 -- CREATING + WHEN 10 THEN 0x1000 -- OFFLINE_SECONDARY + ELSE 0x0010 -- SUSPECT + END) AS [status] + FROM [sys].[databases] AS [db] + WHERE ([db].[name] = testdb1 OR testdb1 IS NULL) + + + + + + + + + + + + + + + + + + + + + + + + + + + + testdb1 + 5 + 2014-12-06T13:39:04.633-08:00 + SQL_Latin1_General_CP1_CI_AS + false + false + false + false + false + 1 +
+
+
+
+ + SqlAuthv12MockTests.SetAzureSqlDatabaseWithSqlAuthv12 + testdb1 + SELECT DatabasePropertyEx(testdb1, 'Edition') as edition, DatabasePropertyEx(testdb1, 'MaxSizeInBytes') as maxSizeBytes + + + + + + + + + + + + + + + + + + + + Basic + 1073741824 +
+
+
+
+ + SqlAuthv12MockTests.SetAzureSqlDatabaseWithSqlAuthv12 + master + + SELECT + [db].[name], + [db].[database_id], + [db].[create_date], + [db].[collation_name], + [db].[is_read_only], + [db].[is_query_store_on], + [db].[is_recursive_triggers_on], + [db].[is_federation_member], + CONVERT (bit, CASE WHEN [db].[name] in ('master') THEN 1 ELSE [db].[is_distributor] END) AS [is_system_object], + CONVERT (int, + CASE + WHEN [db].[is_in_standby] = 1 THEN 0x0040 + ELSE 0 + END | + CASE + WHEN [db].[is_cleanly_shutdown] = 1 THEN 0x0080 + ELSE 0 + END | + CASE [db].[state] + WHEN 0 THEN 0x0001 -- NORMAL + WHEN 1 THEN 0x0002 -- RESTORING + WHEN 2 THEN 0x0008 -- RECOVERING + WHEN 3 THEN 0x0004 -- RECOVERY_PENDING + WHEN 4 THEN 0x0010 -- SUSPECT + WHEN 5 THEN 0x0100 -- EMERGENCY + WHEN 6 THEN 0x0020 -- OFFLINE + WHEN 7 THEN 0x0400 -- COPYING + WHEN 9 THEN 0x0800 -- CREATING + WHEN 10 THEN 0x1000 -- OFFLINE_SECONDARY + ELSE 0x0010 -- SUSPECT + END) AS [status] + FROM [sys].[databases] AS [db] + WHERE ([db].[name] = testdb2 OR testdb2 IS NULL) + + + + + + + + + + + + + + + + + + + + + + + + + + + + testdb2 + 6 + 2014-12-06T13:50:00.72-08:00 + Japanese_CI_AS + false + false + false + false + false + 1 +
+
+
+
+ + SqlAuthv12MockTests.SetAzureSqlDatabaseWithSqlAuthv12 + testdb2 + SELECT DatabasePropertyEx(testdb2, 'Edition') as edition, DatabasePropertyEx(testdb2, 'MaxSizeInBytes') as maxSizeBytes + + + + + + + + + + + + + + + + + + + + Standard + 107374182400 +
+
+
+
+ + SqlAuthv12MockTests.SetAzureSqlDatabaseWithSqlAuthv12 + master + + SELECT + [db].[name], + [db].[database_id], + [db].[create_date], + [db].[collation_name], + [db].[is_read_only], + [db].[is_query_store_on], + [db].[is_recursive_triggers_on], + [db].[is_federation_member], + CONVERT (bit, CASE WHEN [db].[name] in ('master') THEN 1 ELSE [db].[is_distributor] END) AS [is_system_object], + CONVERT (int, + CASE + WHEN [db].[is_in_standby] = 1 THEN 0x0040 + ELSE 0 + END | + CASE + WHEN [db].[is_cleanly_shutdown] = 1 THEN 0x0080 + ELSE 0 + END | + CASE [db].[state] + WHEN 0 THEN 0x0001 -- NORMAL + WHEN 1 THEN 0x0002 -- RESTORING + WHEN 2 THEN 0x0008 -- RECOVERING + WHEN 3 THEN 0x0004 -- RECOVERY_PENDING + WHEN 4 THEN 0x0010 -- SUSPECT + WHEN 5 THEN 0x0100 -- EMERGENCY + WHEN 6 THEN 0x0020 -- OFFLINE + WHEN 7 THEN 0x0400 -- COPYING + WHEN 9 THEN 0x0800 -- CREATING + WHEN 10 THEN 0x1000 -- OFFLINE_SECONDARY + ELSE 0x0010 -- SUSPECT + END) AS [status] + FROM [sys].[databases] AS [db] + WHERE ([db].[name] = testdb3alt OR testdb3alt IS NULL) + + + + + + + + + + + + + + + + + + + + + + + + + + + + testdb3alt + 7 + 2014-12-06T13:41:09.09-08:00 + SQL_Latin1_General_CP1_CI_AS + false + false + false + false + false + 1 +
+
+
+
+ + SqlAuthv12MockTests.SetAzureSqlDatabaseWithSqlAuthv12 + testdb3alt + SELECT DatabasePropertyEx(testdb3alt, 'Edition') as edition, DatabasePropertyEx(testdb3alt, 'MaxSizeInBytes') as maxSizeBytes + + + + + + + + + + + + + + + + + + + + Standard + 107374182400 +
+
+
+
+ + SqlAuthv12MockTests.SetAzureSqlDatabaseWithSqlAuthv12 + master + + SELECT + [db].[name], + [db].[database_id], + [db].[create_date], + [db].[collation_name], + [db].[is_read_only], + [db].[is_query_store_on], + [db].[is_recursive_triggers_on], + [db].[is_federation_member], + CONVERT (bit, CASE WHEN [db].[name] in ('master') THEN 1 ELSE [db].[is_distributor] END) AS [is_system_object], + CONVERT (int, + CASE + WHEN [db].[is_in_standby] = 1 THEN 0x0040 + ELSE 0 + END | + CASE + WHEN [db].[is_cleanly_shutdown] = 1 THEN 0x0080 + ELSE 0 + END | + CASE [db].[state] + WHEN 0 THEN 0x0001 -- NORMAL + WHEN 1 THEN 0x0002 -- RESTORING + WHEN 2 THEN 0x0008 -- RECOVERING + WHEN 3 THEN 0x0004 -- RECOVERY_PENDING + WHEN 4 THEN 0x0010 -- SUSPECT + WHEN 5 THEN 0x0100 -- EMERGENCY + WHEN 6 THEN 0x0020 -- OFFLINE + WHEN 7 THEN 0x0400 -- COPYING + WHEN 9 THEN 0x0800 -- CREATING + WHEN 10 THEN 0x1000 -- OFFLINE_SECONDARY + ELSE 0x0010 -- SUSPECT + END) AS [status] + FROM [sys].[databases] AS [db] + WHERE ([db].[name] = testdb4 OR testdb4 IS NULL) + + + + + + + + + + + + + + + + + + + + + + + + + + + + testdb4 + 8 + 2014-12-06T13:50:25.71-08:00 + SQL_Latin1_General_CP1_CI_AS + false + false + false + false + false + 1 +
+
+
+
+ + SqlAuthv12MockTests.SetAzureSqlDatabaseWithSqlAuthv12 + testdb4 + SELECT DatabasePropertyEx(testdb4, 'Edition') as edition, DatabasePropertyEx(testdb4, 'MaxSizeInBytes') as maxSizeBytes + + + + + + + + + + + + + + + + + + + + Standard + 268435456000 +
+
+
+
diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlDatabaseServerContextTests.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlDatabaseServerContextTests.cs index 8e6774c1f32f..6e9db9966980 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlDatabaseServerContextTests.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlDatabaseServerContextTests.cs @@ -365,8 +365,7 @@ public static void CreateServerContextSqlAuthV2( CultureInfo.InvariantCulture, @"{1} = New-AzureSqlDatabaseServerContext " + @"-ManageUrl {0} " + - @"-Credential $credential " + - @"-Version 12.0 ", + @"-Credential $credential ", manageUrl, contextVariable), contextVariable); diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs index 456e0fa1268d..3bddaab3c53a 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs @@ -25,7 +25,6 @@ namespace Microsoft.WindowsAzure.Commands.SqlDatabase.Test.UnitTests.Database.Cmdlet { - [RecordMockDataResults("./")] [TestClass] public class SqlAuthv12MockTests { @@ -36,8 +35,6 @@ public class SqlAuthv12MockTests [TestInitialize] public void Setup() { - var mockConn = new MockSqlConnection(); - TSqlConnectionContext.MockSqlConnection = mockConn; } [TestCleanup] @@ -46,9 +43,13 @@ public void Cleanup() // Do any test clean up here. } + //[RecordMockDataResults("./")] [TestMethod] public void NewAzureSqlDatabaseWithSqlAuthv12() { + var mockConn = new MockSqlConnection(); + TSqlConnectionContext.MockSqlConnection = mockConn; + using (System.Management.Automation.PowerShell powershell = System.Management.Automation.PowerShell.Create()) { @@ -121,6 +122,171 @@ public void NewAzureSqlDatabaseWithSqlAuthv12() } + //[RecordMockDataResults("./")] + [TestMethod] + public void GetAzureSqlDatabaseWithSqlAuthv12() + { + var mockConn = new MockSqlConnection(); + TSqlConnectionContext.MockSqlConnection = mockConn; + + using (System.Management.Automation.PowerShell powershell = + System.Management.Automation.PowerShell.Create()) + { + // Create a context + NewAzureSqlDatabaseServerContextTests.CreateServerContextSqlAuthV2( + powershell, + manageUrl, + username, + password, + "$context"); + + Collection database1, database2, database3; + + database1 = powershell.InvokeBatchScript( + @"$testdb1 = Get-AzureSqlDatabase " + + @"-Context $context " + + @"-DatabaseName testdb1 ", + @"$testdb1"); + database2 = powershell.InvokeBatchScript( + @"$testdb2 = Get-AzureSqlDatabase " + + @"-Context $context " + + @"-Database $testdb1 ", + @"$testdb2"); + database3 = powershell.InvokeBatchScript( + @"$testdb3 = Get-AzureSqlDatabase " + + @"-Context $context ", + @"$testdb3"); + + Assert.AreEqual(0, powershell.Streams.Error.Count, "Errors during run!"); + Assert.AreEqual(0, powershell.Streams.Warning.Count, "Warnings during run!"); + powershell.Streams.ClearStreams(); + + Services.Server.Database database = database1.Single().BaseObject as Services.Server.Database; + Assert.IsTrue(database != null, "Expecting a Database object"); + ValidateDatabaseProperties(database, "testdb1", "Standard", 250, 268435456000L, "SQL_Latin1_General_CP1_CI_AS", false, DatabaseTestHelper.StandardS0SloGuid); + + database = database2.Single().BaseObject as Services.Server.Database; + Assert.IsTrue(database != null, "Expecting a Database object"); + ValidateDatabaseProperties(database, "testdb1", "Standard", 250, 268435456000L, "SQL_Latin1_General_CP1_CI_AS", false, DatabaseTestHelper.StandardS0SloGuid); + + Assert.IsTrue(database3.Count == 5); + foreach (var entry in database3) + { + var db = entry.BaseObject as Services.Server.Database; + Assert.IsTrue(db != null, "Expecting a Database object"); + } + } + } + + //[RecordMockDataResults("./")] + [TestMethod] + public void SetAzureSqlDatabaseWithSqlAuthv12() + { + var mockConn = new MockSqlConnection(); + TSqlConnectionContext.MockSqlConnection = mockConn; + + using (System.Management.Automation.PowerShell powershell = + System.Management.Automation.PowerShell.Create()) + { + // Create a context + NewAzureSqlDatabaseServerContextTests.CreateServerContextSqlAuthV2( + powershell, + manageUrl, + username, + password, + "$context"); + + Collection database1, database2, database3, database4; + + database1 = powershell.InvokeBatchScript( + @"$testdb1 = Set-AzureSqlDatabase " + + @"-Context $context " + + @"-DatabaseName testdb1 " + + @"-Edition Basic " + + @"-MaxSizeGb 1 " + + @"-Force " + + @"-PassThru ", + @"$testdb1"); + database2 = powershell.InvokeBatchScript( + @"$testdb2 = Set-AzureSqlDatabase " + + @"-Context $context " + + @"-DatabaseName testdb2 " + + @"-Edition Standard " + + @"-MaxSizeBytes 107374182400 " + + @"-Force " + + @"-PassThru ", + @"$testdb2"); + database3 = powershell.InvokeBatchScript( + @"$testdb3 = Set-AzureSqlDatabase " + + @"-Context $context " + + @"-DatabaseName testdb3 " + + @"-NewDatabaseName testdb3alt " + + @"-Force " + + @"-PassThru ", + @"$testdb3"); + var slo = powershell.InvokeBatchScript( + @"$so = Get-AzureSqlDatabaseServiceObjective " + + @"-Context $context " + + @"-ServiceObjectiveName S0 ", + @"$so"); + database4 = powershell.InvokeBatchScript( + @"$testdb4 = Set-AzureSqlDatabase " + + @"-Context $context " + + @"-DatabaseName testdb4 " + + @"-ServiceObjective $so " + + @"-Force " + + @"-PassThru ", + @"$testdb4"); + + // + // Wait for operations to complete + // + + database1 = powershell.InvokeBatchScript( + @"$testdb1 = Get-AzureSqlDatabase " + + @"-Context $context " + + @"-DatabaseName testdb1 ", + @"$testdb1"); + database2 = powershell.InvokeBatchScript( + @"$testdb2 = Get-AzureSqlDatabase " + + @"-Context $context " + + @"-DatabaseName testdb2 ", + @"$testdb2"); + database3 = powershell.InvokeBatchScript( + @"$testdb3 = Get-AzureSqlDatabase " + + @"-Context $context " + + @"-DatabaseName testdb3alt ", + @"$testdb3"); + database4 = powershell.InvokeBatchScript( + @"$testdb4 = Get-AzureSqlDatabase " + + @"-Context $context " + + @"-DatabaseName testdb4 ", + @"$testdb4"); + + Assert.AreEqual(0, powershell.Streams.Error.Count, "Errors during run!"); + Assert.AreEqual(0, powershell.Streams.Warning.Count, "Warnings during run!"); + powershell.Streams.ClearStreams(); + + Services.Server.Database database = database1.Single().BaseObject as Services.Server.Database; + Assert.IsTrue(database != null, "Expecting a Database object"); + ValidateDatabaseProperties(database, "testdb1", "Basic", 1, 1073741824L, "SQL_Latin1_General_CP1_CI_AS", false, DatabaseTestHelper.BasicSloGuid); + + database = database2.Single().BaseObject as Services.Server.Database; + Assert.IsTrue(database != null, "Expecting a Database object"); + ValidateDatabaseProperties(database, "testdb2", "Standard", 100, 107374182400L, "Japanese_CI_AS", false, DatabaseTestHelper.StandardS0SloGuid); + + database = database3.Single().BaseObject as Services.Server.Database; + Assert.IsTrue(database != null, "Expecting a Database object"); + ValidateDatabaseProperties(database, "testdb3alt", "Standard", 100, 107374182400L, "SQL_Latin1_General_CP1_CI_AS", false, DatabaseTestHelper.StandardS0SloGuid); + + database = database4.Single().BaseObject as Services.Server.Database; + Assert.IsTrue(database != null, "Expecting a Database object"); + ValidateDatabaseProperties(database, "testdb4", "Standard", 250, 268435456000L, "SQL_Latin1_General_CP1_CI_AS", false, DatabaseTestHelper.StandardS0SloGuid); + } + } + + #region Helpers + /// /// Validate the properties of a database against the expected values supplied as input. /// @@ -148,5 +314,7 @@ internal static void ValidateDatabaseProperties( Assert.AreEqual(isSystem, database.IsSystemObject); // Assert.AreEqual(slo, database.ServiceObjectiveId); } + + #endregion } } diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSettings.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSettings.cs index e710e0ca2e82..0af96e1601a7 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSettings.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSettings.cs @@ -26,6 +26,9 @@ namespace Microsoft.WindowsAzure.Commands.SqlDatabase.Test.UnitTests.TSql { internal sealed class MockSettings { + /// + /// Gets or sets the id for the mock + /// private string mockId; /// @@ -143,8 +146,6 @@ from StackFrame frame in stackFrames if (testMethodFrame != null) { settings.mockId = GetMockId(testMethodFrame); - //settings.initializeMethod = FindMockSetupMethod(testMethodFrame); - //settings.cleanupMethod = FindMockSetupMethod(testMethodFrame); RecordMockDataResultsAttribute recordAttr = FindRecordMockDataResultsAttribute(testMethodFrame); if (recordAttr != null) @@ -197,40 +198,16 @@ private static RecordMockDataResultsAttribute FindRecordMockDataResultsAttribute return recordAttr; } - //private static SetupMethodDelegate FindMockSetupMethod(StackFrame testMethodFrame) - // where T : Attribute - //{ - // Type declaringType = testMethodFrame.GetMethod().DeclaringType; - // MethodInfo[] methods = declaringType.GetMethods(); - - // foreach (MethodInfo method in methods) - // { - // if (method.GetCustomAttribute() != null) - // { - // if (!method.IsStatic) - // { - // throw new NotSupportedException("Non-static mock setup method are not supported."); - // } - - // return delegate(SqlConnection connection) - // { - // method.Invoke(null, new object[] { connection }); - // }; - // } - // } - - // return null; - //} - private static string GetMockId(StackFrame testMethodFrame) { List parts = new List(); + parts.Insert(0, testMethodFrame.GetMethod().Name); for (Type type = testMethodFrame.GetMethod().DeclaringType; type != null; type = type.DeclaringType) { parts.Insert(0, type.Name); } - + return String.Join(".", parts.ToArray()); } } diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlCommand.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlCommand.cs index a3ebd597430b..6fce34eac110 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlCommand.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlCommand.cs @@ -173,24 +173,28 @@ public override bool DesignTimeVisible protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) { Assert.IsTrue((this.Connection.State & ConnectionState.Open) == ConnectionState.Open, "Connection has to be opened when executing a command"); + + MockQueryResult mockResult = null; - string commandKey = this.GetCommandKey(); - MockQueryResult mockResult = FindMockResult(this.settings.MockId, this.Connection.Database, commandKey, this.settings.IsolatedQueries); - - if (mockResult == null && this.settings.RecordingMode) + if (this.settings.RecordingMode) { mockResult = this.RecordExecuteDbDataReader(); } - - if (mockResult == null || mockResult.DataSetResult == null) + else { - if (mockResult != null && mockResult.ExceptionResult != null) - { - throw mockResult.ExceptionResult.Exception; - } - else + string commandKey = this.GetCommandKey(); + mockResult = FindMockResult(this.settings.MockId, this.Connection.Database, commandKey, this.settings.IsolatedQueries); + + if (mockResult == null || mockResult.DataSetResult == null) { - throw new NotSupportedException(string.Format("Mock SqlConnection does not know how to handle query: '{0}'", commandKey)); + if (mockResult != null && mockResult.ExceptionResult != null) + { + throw mockResult.ExceptionResult.Exception; + } + else + { + throw new NotSupportedException(string.Format("Mock SqlConnection does not know how to handle query: '{0}'", commandKey)); + } } } @@ -244,13 +248,17 @@ public override object ExecuteScalar() { Assert.IsTrue((this.Connection.State & ConnectionState.Open) == ConnectionState.Open, "Connection has to be opened when executing command"); - string commandKey = this.GetCommandKey(); - MockQueryResult mockResult = FindMockResult(this.settings.MockId, this.Connection.Database, commandKey, this.settings.IsolatedQueries); + MockQueryResult mockResult = null; - if (mockResult == null && this.settings.RecordingMode) + if (this.settings.RecordingMode) { mockResult = this.RecordExecuteScalar(); } + else + { + string commandKey = this.GetCommandKey(); + FindMockResult(this.settings.MockId, this.Connection.Database, commandKey, this.settings.IsolatedQueries); + } return mockResult != null ? mockResult.ScalarResult : null; } @@ -350,7 +358,14 @@ private MockQueryResult RecordExecuteDbDataReader() foreach (DbParameter param in this.Parameters) { SqlParameter sqlParam = new SqlParameter(param.ParameterName, param.DbType); - sqlParam.Value = param.Value; + if (param.Value == null) + { + sqlParam.Value = DBNull.Value; + } + else + { + sqlParam.Value = param.Value; + } cmd.Parameters.Add(sqlParam); } @@ -473,7 +488,7 @@ private string GetCommandKey() key = key.Replace(parameter.ParameterName, value); } - key = key.Replace("\r", string.Empty).Replace("\n", Environment.NewLine); + //key = key.Replace("\r", string.Empty).Replace("\n", Environment.NewLine); key = TempTableNameRegex.Replace(key, TempTableName); @@ -599,25 +614,6 @@ private static void InitializeMockResults() } } } - - //foreach (string rn in resourceNames) - //{ - // using (Stream stream = Assembly.GetExecutingAssembly().GetManifestResourceStream(rn)) - // { - // if (stream.Length > 0) - // { - // MockQueryResultSet mockResultSet = MockQueryResultSet.Deserialize(stream); - - // if (mockResultSet.CommandResults != null) - // { - // foreach (MockQueryResult mockResult in mockResultSet.CommandResults) - // { - // AddMockResult(mockResult); - // } - // } - // } - // } - //} } #endregion diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlConnection.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlConnection.cs index 52b0c93d47c0..8a821bb017d4 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlConnection.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlConnection.cs @@ -53,15 +53,6 @@ public static DbConnection CreateConnection(string connectionString) /// private void InitializeMockEnvironment() { - //if (this.settings.RecordingMode && this.settings.InitializeMethod != null) - //{ - // using (SqlConnection connection = new SqlConnection(this.settings.SqlConnectionString)) - // { - // connection.Open(); - - // this.settings.InitializeMethod(connection); - // } - //} } /// @@ -69,15 +60,6 @@ private void InitializeMockEnvironment() /// private void CleanupMockEnvironment() { - //if (this.settings.RecordingMode && this.settings.CleanupMethod != null) - //{ - // using (SqlConnection connection = new SqlConnection(this.settings.SqlConnectionString)) - // { - // connection.Open(); - - // this.settings.CleanupMethod(connection); - // } - //} } /// diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs index 6c84fc99ce98..2c044d5991ed 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs @@ -210,7 +210,8 @@ public Database[] GetDatabases() command.CommandText = getDatabaseQuery; DbParameter param = command.CreateParameter(); param.ParameterName = "@name"; - param.Value = DBNull.Value; + param.Value = null; + command.Parameters.Add(param); connection.Open(); From 6b567189dbf53e408929e3c656dfb994e9b3fc02 Mon Sep 17 00:00:00 2001 From: adamkr Date: Mon, 8 Dec 2014 08:32:37 -0800 Subject: [PATCH 5/6] Adding some additional checks for collation name. --- .../Services/Server/TSqlConnectionContext.cs | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs index 2c044d5991ed..a6bf18160394 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs @@ -381,10 +381,12 @@ public Database CreateNewDatabase( } } + SqlCollationCheck(databaseCollation); + commandText = string.Format( commandText, SqlEscape(databaseName), - SqlEscape(databaseCollation), + databaseCollation, SqlEscape(maxSizeVal), SqlEscape(databaseEdition.ToString()), SqlEscape(serviceObjectiveName)); @@ -405,6 +407,23 @@ public Database CreateNewDatabase( return GetDatabase(databaseName); } + private void SqlCollationCheck(string databaseCollation) + { + bool isValid = databaseCollation.All( (c) => + { + if(!char.IsLetterOrDigit(c) && c != '_') + { + return false; + } + return true; + }); + + if(!isValid) + { + throw new ArgumentException("Invalid Collation", "Collation"); + } + } + /// /// Updates the property on the database with the name . /// From 8c00f011930b1d154ffa2f56e55fe0e4c158a38c Mon Sep 17 00:00:00 2001 From: adamkr Date: Mon, 8 Dec 2014 13:08:16 -0800 Subject: [PATCH 6/6] Fixing some unit tests. --- .../Cmdlet/ImportExportCmdletTests.cs | 5 ++ .../NewAzureSqlDatabaseServerContextTests.cs | 21 +++++- .../Database/Cmdlet/SqlAuthv12MockTests.cs | 10 +-- .../UnitTests/TSql/MockSqlCommand.cs | 2 - .../NewAzureSqlDatabaseServerContext.cs | 1 - .../Services/Server/SqlAuthContextFactory.cs | 65 ++++++++++++++++++- .../Services/Server/TSqlConnectionContext.cs | 9 +++ 7 files changed, 102 insertions(+), 11 deletions(-) diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/ImportExportCmdletTests.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/ImportExportCmdletTests.cs index 2c61a4bcb0bb..433dc8ecb49d 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/ImportExportCmdletTests.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/ImportExportCmdletTests.cs @@ -24,6 +24,7 @@ using Microsoft.WindowsAzure.Commands.SqlDatabase.Test.Utilities; using Microsoft.WindowsAzure.Commands.Test.Utilities.Common; using Microsoft.WindowsAzure.Commands.Utilities.Common; +using Microsoft.WindowsAzure.Commands.SqlDatabase.Services.Server; namespace Microsoft.WindowsAzure.Commands.SqlDatabase.Test.UnitTests.Database.Cmdlet { @@ -112,6 +113,10 @@ public void ImportExportAzureSqlDatabaseTests() @" -StorageAccountKey $storageAccountKey"); }).FirstOrDefault(); + // Tell the sql auth factory to create a v2 context (skip checking sql version using select query). + // + SqlAuthContextFactory.sqlVersion = SqlAuthContextFactory.SqlVersion.v2; + //testSession.ServiceBaseUri = new Uri("https://lqtqbo6kkp.database.windows.net"); Collection databaseContext = MockServerHelper.ExecuteWithMock( testSession, diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlDatabaseServerContextTests.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlDatabaseServerContextTests.cs index 6e9db9966980..9d81af0d8211 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlDatabaseServerContextTests.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/NewAzureSqlDatabaseServerContextTests.cs @@ -126,6 +126,9 @@ public void NewAzureSqlDatabaseServerContextWithSqlAuth() UnitTestHelper.ImportAzureModule(powershell); UnitTestHelper.CreateTestCredential(powershell); + // Tell the sql auth factory to create a v2 context (skip checking sql version using select query). + // + SqlAuthContextFactory.sqlVersion = SqlAuthContextFactory.SqlVersion.v2; using (AsyncExceptionManager exceptionManager = new AsyncExceptionManager()) { Collection serverContext; @@ -190,6 +193,10 @@ public void NewAzureSqlDatabaseServerContextWithSqlAuthNegativeCases() using (AsyncExceptionManager exceptionManager = new AsyncExceptionManager()) { + // Tell the sql auth factory to create a v2 context (skip checking sql version using select query). + // + SqlAuthContextFactory.sqlVersion = SqlAuthContextFactory.SqlVersion.v2; + // Test warning when different $metadata is received. Collection serverContext; using (new MockHttpServer( @@ -212,6 +219,10 @@ public void NewAzureSqlDatabaseServerContextWithSqlAuthNegativeCases() Assert.AreEqual(2, powershell.Streams.Warning.Count, "Should have warning!"); powershell.Streams.ClearStreams(); + // Tell the sql auth factory to create a v2 context (skip checking sql version using select query). + // + SqlAuthContextFactory.sqlVersion = SqlAuthContextFactory.SqlVersion.v2; + // Test error case using (new MockHttpServer( exceptionManager, @@ -308,6 +319,10 @@ public static void CreateServerContextSqlAuth( testSession.SessionProperties["Username"], testSession.SessionProperties["Password"]); + // Tell the sql auth factory to create a v2 context (skip checking sql version using select query). + // + SqlAuthContextFactory.sqlVersion = SqlAuthContextFactory.SqlVersion.v2; + Collection serverContext; using (AsyncExceptionManager exceptionManager = new AsyncExceptionManager()) { @@ -344,7 +359,7 @@ public static void CreateServerContextSqlAuth( /// Common helper method for other tests to create a context for ESA server. /// /// The variable name that will hold the new context. - public static void CreateServerContextSqlAuthV2( + public static void CreateServerContextSqlAuthV12( System.Management.Automation.PowerShell powershell, string manageUrl, string username, @@ -357,6 +372,10 @@ public static void CreateServerContextSqlAuthV2( username, password); + // Tell the sql auth factory to create a v22 context (skip checking sql version using select query). + // + SqlAuthContextFactory.sqlVersion = SqlAuthContextFactory.SqlVersion.v12; + Collection serverContext; using (AsyncExceptionManager exceptionManager = new AsyncExceptionManager()) { diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs index 3bddaab3c53a..38b2ce5a8e60 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/Database/Cmdlet/SqlAuthv12MockTests.cs @@ -13,6 +13,7 @@ // --- using Microsoft.VisualStudio.TestTools.UnitTesting; +using Microsoft.WindowsAzure.Commands.SqlDatabase.Database.Cmdlet; using Microsoft.WindowsAzure.Commands.SqlDatabase.Services.Server; using Microsoft.WindowsAzure.Commands.SqlDatabase.Test.UnitTests.TSql; using System; @@ -30,7 +31,7 @@ public class SqlAuthv12MockTests { public static string username = "testlogin"; public static string password = "MyS3curePa$$w0rd"; - public static string manageUrl = "https://mysvr2.adamkr-vm04.onebox.xdb.mscds.com"; + public static string manageUrl = "https://mysvr2.database.windows.net"; [TestInitialize] public void Setup() @@ -54,7 +55,7 @@ public void NewAzureSqlDatabaseWithSqlAuthv12() System.Management.Automation.PowerShell.Create()) { // Create a context - NewAzureSqlDatabaseServerContextTests.CreateServerContextSqlAuthV2( + NewAzureSqlDatabaseServerContextTests.CreateServerContextSqlAuthV12( powershell, manageUrl, username, @@ -121,7 +122,6 @@ public void NewAzureSqlDatabaseWithSqlAuthv12() } } - //[RecordMockDataResults("./")] [TestMethod] public void GetAzureSqlDatabaseWithSqlAuthv12() @@ -133,7 +133,7 @@ public void GetAzureSqlDatabaseWithSqlAuthv12() System.Management.Automation.PowerShell.Create()) { // Create a context - NewAzureSqlDatabaseServerContextTests.CreateServerContextSqlAuthV2( + NewAzureSqlDatabaseServerContextTests.CreateServerContextSqlAuthV12( powershell, manageUrl, username, @@ -189,7 +189,7 @@ public void SetAzureSqlDatabaseWithSqlAuthv12() System.Management.Automation.PowerShell.Create()) { // Create a context - NewAzureSqlDatabaseServerContextTests.CreateServerContextSqlAuthV2( + NewAzureSqlDatabaseServerContextTests.CreateServerContextSqlAuthV12( powershell, manageUrl, username, diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlCommand.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlCommand.cs index 6fce34eac110..77044bea8a77 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlCommand.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase.Test/UnitTests/TSql/MockSqlCommand.cs @@ -488,8 +488,6 @@ private string GetCommandKey() key = key.Replace(parameter.ParameterName, value); } - //key = key.Replace("\r", string.Empty).Replace("\n", Environment.NewLine); - key = TempTableNameRegex.Replace(key, TempTableName); return key; diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabaseServerContext.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabaseServerContext.cs index bcf2cffbe686..f9ff613afc93 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabaseServerContext.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase/Database/Cmdlet/NewAzureSqlDatabaseServerContext.cs @@ -162,7 +162,6 @@ private AzureSubscription CurrentSubscription #endregion - /// /// Connect to a Azure SQL Server with the given ManagementService Uri using /// SQL authentication credentials. diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/SqlAuthContextFactory.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/SqlAuthContextFactory.cs index 0cc49e16ee9d..ec28f7dfc2a7 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/SqlAuthContextFactory.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/SqlAuthContextFactory.cs @@ -1,4 +1,18 @@ -using Microsoft.WindowsAzure.Commands.SqlDatabase.Properties; +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using Microsoft.WindowsAzure.Commands.SqlDatabase.Properties; using Microsoft.WindowsAzure.Commands.SqlDatabase.Services.Common; using System; using System.Collections.Generic; @@ -14,6 +28,38 @@ namespace Microsoft.WindowsAzure.Commands.SqlDatabase.Services.Server { public class SqlAuthContextFactory { + /// + /// The different sql versions available + /// + internal enum SqlVersion + { + /// + /// Not set. Determine by querying the server + /// + None, + + /// + /// V2 server + /// + v2, + + /// + /// V12 server + /// + v12 + } + internal static SqlVersion sqlVersion = SqlVersion.None; + + /// + /// Gets a sql auth connection context. + /// + /// The cmdlet requesting the context + /// The name of the server to connect to + /// The manage url of the server + /// The credentials to connect to the server + /// The session activity ID + /// The URI for management service + /// The connection context public static IServerDataServiceContext GetContext( PSCmdlet cmdlet, string serverName, @@ -22,7 +68,22 @@ public static IServerDataServiceContext GetContext( Guid sessionActivityId, Uri managementServiceUri) { - Version version = GetVersion(manageUrl, credentials); + Version version; + + // If a version was specified (by tests) us it. + if (sqlVersion == SqlVersion.v2) + { + version = new Version(11, 0); + } + else if (sqlVersion == SqlVersion.v12) + { + version = new Version(12, 0); + } + else // If no version specified, determine the version by querying the server. + { + version = GetVersion(manageUrl, credentials); + } + sqlVersion = SqlVersion.None; IServerDataServiceContext context = null; diff --git a/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs index a6bf18160394..51faaac957f8 100644 --- a/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs +++ b/src/ServiceManagement/Sql/Commands.SqlDatabase/Services/Server/TSqlConnectionContext.cs @@ -407,8 +407,17 @@ public Database CreateNewDatabase( return GetDatabase(databaseName); } + /// + /// Checks to make sure the collation only contains alphanumeric characters and '_' + /// + /// The string to verify private void SqlCollationCheck(string databaseCollation) { + if(string.IsNullOrEmpty (databaseCollation)) + { + return; + } + bool isValid = databaseCollation.All( (c) => { if(!char.IsLetterOrDigit(c) && c != '_')