From 45ed3c43f491f0fe490c7315b1b833befa8be2e1 Mon Sep 17 00:00:00 2001 From: Becca Petrin Date: Tue, 23 Apr 2019 09:06:14 -0700 Subject: [PATCH] Merge pull request #6356 from kedarkale27/master Update mssql.go --- physical/mssql/mssql.go | 7 ++--- physical/mssql/mssql_test.go | 58 ++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/physical/mssql/mssql.go b/physical/mssql/mssql.go index d0029dd160a8..6d611f7fe3e8 100644 --- a/physical/mssql/mssql.go +++ b/physical/mssql/mssql.go @@ -123,16 +123,13 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen "') CREATE TABLE " + dbTable + " (Path VARCHAR(512) PRIMARY KEY, Value VARBINARY(MAX))" if schema != "dbo" { - if _, err := db.Exec("USE " + database); err != nil { - return nil, errwrap.Wrapf("failed to switch mssql database: {{err}}", err) - } var num int - err = db.QueryRow("SELECT 1 FROM sys.schemas WHERE name = '" + schema + "'").Scan(&num) + err = db.QueryRow("SELECT 1 FROM " + database + ".sys.schemas WHERE name = '" + schema + "'").Scan(&num) switch { case err == sql.ErrNoRows: - if _, err := db.Exec("CREATE SCHEMA " + schema); err != nil { + if _, err := db.Exec("USE " + database + "; EXEC ('CREATE SCHEMA " + schema + "')"); err != nil { return nil, errwrap.Wrapf("failed to create mssql schema: {{err}}", err) } diff --git a/physical/mssql/mssql_test.go b/physical/mssql/mssql_test.go index 9c55228018fe..a98165bdb7b5 100644 --- a/physical/mssql/mssql_test.go +++ b/physical/mssql/mssql_test.go @@ -27,6 +27,63 @@ func TestMSSQLBackend(t *testing.T) { table = "test" } + schema := os.Getenv("MSSQL_SCHEMA") + if schema == "" { + schema = "test" + } + + username := os.Getenv("MSSQL_USERNAME") + password := os.Getenv("MSSQL_PASSWORD") + + // Run vault tests + logger := logging.NewVaultLogger(log.Debug) + + b, err := NewMSSQLBackend(map[string]string{ + "server": server, + "database": database, + "table": table, + "schema": schema, + "username": username, + "password": password, + }, logger) + + if err != nil { + t.Fatalf("Failed to create new backend: %v", err) + } + + defer func() { + mssql := b.(*MSSQLBackend) + _, err := mssql.client.Exec("DROP TABLE " + mssql.dbTable) + if err != nil { + t.Fatalf("Failed to drop table: %v", err) + } + }() + + physical.ExerciseBackend(t, b) + physical.ExerciseBackend_ListPrefix(t, b) +} + +func TestMSSQLBackend_schema(t *testing.T) { + server := os.Getenv("MSSQL_SERVER") + if server == "" { + t.SkipNow() + } + + database := os.Getenv("MSSQL_DB") + if database == "" { + database = "test" + } + + table := os.Getenv("MSSQL_TABLE") + if table == "" { + table = "test" + } + + schema := os.Getenv("MSSQL_SCHEMA") + if schema == "" { + schema = "test" + } + username := os.Getenv("MSSQL_USERNAME") password := os.Getenv("MSSQL_PASSWORD") @@ -36,6 +93,7 @@ func TestMSSQLBackend(t *testing.T) { b, err := NewMSSQLBackend(map[string]string{ "server": server, "database": database, + "schema": schema, "table": table, "username": username, "password": password,