Skip to content

Commit

Permalink
Merge pull request #227 from zhevron/mssql-params
Browse files Browse the repository at this point in the history
Fix incorrect syntax errors for MSSQL driver
  • Loading branch information
dhui authored May 27, 2019
2 parents e877644 + dd0ead0 commit 8437fe6
Show file tree
Hide file tree
Showing 22 changed files with 43 additions and 39 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ WORKDIR /go/src/github.com/golang-migrate/migrate
COPY . ./

ENV GO111MODULE=on
ENV DATABASES="postgres mysql redshift cassandra spanner cockroachdb clickhouse mongodb mssql"
ENV DATABASES="postgres mysql redshift cassandra spanner cockroachdb clickhouse mongodb sqlserver"
ENV SOURCES="file go_bindata github aws_s3 google_cloud_storage godoc_vfs gitlab"

RUN go build -a -o build/migrate.linux-386 -ldflags="-X main.Version=${VERSION}" -tags "$DATABASES $SOURCES" ./cmd/migrate
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
SOURCE ?= file go_bindata github aws_s3 google_cloud_storage godoc_vfs gitlab
DATABASE ?= postgres mysql redshift cassandra spanner cockroachdb clickhouse mongodb mssql
DATABASE ?= postgres mysql redshift cassandra spanner cockroachdb clickhouse mongodb sqlserver
VERSION ?= $(shell git describe --tags 2>/dev/null | cut -c 2-)
TEST_FLAGS ?=
REPO_OWNER ?= $(shell cd .. && basename "$$(pwd)")
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Database drivers run migrations. [Add a new database?](database/driver.go)
* [CockroachDB](database/cockroachdb)
* [ClickHouse](database/clickhouse)
* [Firebird](database/firebird) ([todo #49](https://github.com/golang-migrate/migrate/issues/49))
* [MS SQL Server](database/mssql)
* [MS SQL Server](database/sqlserver)

### Database URLs

Expand Down
4 changes: 4 additions & 0 deletions database/mssql/README.md → database/sqlserver/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@

See https://github.com/denisenkom/go-mssqldb for full parameter list.

## Note about driver support

Please note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
See https://github.com/denisenkom/go-mssqldb#deprecated for more information.
42 changes: 21 additions & 21 deletions database/mssql/mssql.go → database/sqlserver/sqlserver.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package mssql
package sqlserver

import (
"context"
Expand All @@ -15,9 +15,7 @@ import (
)

func init() {
db := MSSQL{}
database.Register("mssql", &db)
database.Register("sqlserver", &db)
database.Register("sqlserver", &SQLServer{})
}

// DefaultMigrationsTable is the name of the migrations table in the database
Expand All @@ -44,8 +42,8 @@ type Config struct {
SchemaName string
}

// MSSQL connection
type MSSQL struct {
// SQL Server connection
type SQLServer struct {
// Locking and unlocking need to use the same connection
conn *sql.Conn
db *sql.DB
Expand All @@ -55,7 +53,9 @@ type MSSQL struct {
config *Config
}

// WithInstance returns a database instance from an already created database connection
// WithInstance returns a database instance from an already created database connection.
//
// Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver.
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
Expand Down Expand Up @@ -99,7 +99,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
return nil, err
}

ss := &MSSQL{
ss := &SQLServer{
conn: conn,
db: instance,
config: config,
Expand All @@ -113,13 +113,13 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
}

// Open a connection to the database
func (ss *MSSQL) Open(url string) (database.Driver, error) {
func (ss *SQLServer) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
return nil, err
}

db, err := sql.Open("mssql", migrate.FilterCustomQuery(purl).String())
db, err := sql.Open("sqlserver", migrate.FilterCustomQuery(purl).String())
if err != nil {
return nil, err
}
Expand All @@ -139,7 +139,7 @@ func (ss *MSSQL) Open(url string) (database.Driver, error) {
}

// Close the database connection
func (ss *MSSQL) Close() error {
func (ss *SQLServer) Close() error {
connErr := ss.conn.Close()
dbErr := ss.db.Close()
if connErr != nil || dbErr != nil {
Expand All @@ -149,7 +149,7 @@ func (ss *MSSQL) Close() error {
}

// Lock creates an advisory local on the database to prevent multiple migrations from running at the same time.
func (ss *MSSQL) Lock() error {
func (ss *SQLServer) Lock() error {
if ss.isLocked {
return database.ErrLocked
}
Expand All @@ -162,7 +162,7 @@ func (ss *MSSQL) Lock() error {
// This will either obtain the lock immediately and return true,
// or return false if the lock cannot be acquired immediately.
// MS Docs: sp_getapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-getapplock-transact-sql?view=sql-server-2017
query := `EXEC sp_getapplock @Resource = ?, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0`
query := `EXEC sp_getapplock @Resource = @p1, @LockMode = 'Update', @LockOwner = 'Session', @LockTimeout = 0`

var status mssql.ReturnStatus
if _, err = ss.conn.ExecContext(context.Background(), query, aid, &status); err == nil && status > -1 {
Expand All @@ -176,7 +176,7 @@ func (ss *MSSQL) Lock() error {
}

// Unlock froms the migration lock from the database
func (ss *MSSQL) Unlock() error {
func (ss *SQLServer) Unlock() error {
if !ss.isLocked {
return nil
}
Expand All @@ -187,7 +187,7 @@ func (ss *MSSQL) Unlock() error {
}

// MS Docs: sp_releaseapplock: https://docs.microsoft.com/en-us/sql/relational-databases/system-stored-procedures/sp-releaseapplock-transact-sql?view=sql-server-2017
query := `EXEC sp_releaseapplock @Resource = ?, @LockOwner = 'Session'`
query := `EXEC sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'`
if _, err := ss.conn.ExecContext(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
Expand All @@ -197,7 +197,7 @@ func (ss *MSSQL) Unlock() error {
}

// Run the migrations for the database
func (ss *MSSQL) Run(migration io.Reader) error {
func (ss *SQLServer) Run(migration io.Reader) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
Expand All @@ -220,7 +220,7 @@ func (ss *MSSQL) Run(migration io.Reader) error {
}

// SetVersion for the current database
func (ss *MSSQL) SetVersion(version int, dirty bool) error {
func (ss *SQLServer) SetVersion(version int, dirty bool) error {

tx, err := ss.conn.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
Expand All @@ -240,7 +240,7 @@ func (ss *MSSQL) SetVersion(version int, dirty bool) error {
if dirty {
dirtyBit = 1
}
query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES ($1, $2)`
query = `INSERT INTO "` + ss.config.MigrationsTable + `" (version, dirty) VALUES (@p1, @p2)`
if _, err := tx.Exec(query, version, dirtyBit); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
Expand All @@ -257,7 +257,7 @@ func (ss *MSSQL) SetVersion(version int, dirty bool) error {
}

// Version of the current database state
func (ss *MSSQL) Version() (version int, dirty bool, err error) {
func (ss *SQLServer) Version() (version int, dirty bool, err error) {
query := `SELECT TOP 1 version, dirty FROM "` + ss.config.MigrationsTable + `"`
err = ss.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
switch {
Expand All @@ -274,7 +274,7 @@ func (ss *MSSQL) Version() (version int, dirty bool, err error) {
}

// Drop all tables from the database.
func (ss *MSSQL) Drop() error {
func (ss *SQLServer) Drop() error {

// drop all referential integrity constraints
query := `
Expand Down Expand Up @@ -308,7 +308,7 @@ func (ss *MSSQL) Drop() error {
return nil
}

func (ss *MSSQL) ensureVersionTable() (err error) {
func (ss *SQLServer) ensureVersionTable() (err error) {
if err = ss.Lock(); err != nil {
return err
}
Expand Down
16 changes: 8 additions & 8 deletions database/mssql/mssql_test.go → database/sqlserver/sqlserver_test.go
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package mssql
package sqlserver

import (
"context"
Expand Down Expand Up @@ -74,7 +74,7 @@ func Test(t *testing.T) {
}

addr := msConnectionString(ip, port)
p := &MSSQL{}
p := &SQLServer{}
d, err := p.Open(addr)
if err != nil {
t.Fatalf("%v", err)
Expand All @@ -98,7 +98,7 @@ func TestMigrate(t *testing.T) {
}

addr := msConnectionString(ip, port)
p := &MSSQL{}
p := &SQLServer{}
d, err := p.Open(addr)
if err != nil {
t.Fatalf("%v", err)
Expand Down Expand Up @@ -126,7 +126,7 @@ func TestMultiStatement(t *testing.T) {
}

addr := msConnectionString(ip, port)
ms := &MSSQL{}
ms := &SQLServer{}
d, err := ms.Open(addr)
if err != nil {
t.Fatal(err)
Expand All @@ -142,7 +142,7 @@ func TestMultiStatement(t *testing.T) {

// make sure second table exists
var exists int
if err := d.(*MSSQL).conn.QueryRowContext(context.Background(), "SELECT COUNT(1) FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT schema_name()) AND table_catalog = (SELECT db_name())").Scan(&exists); err != nil {
if err := d.(*SQLServer).conn.QueryRowContext(context.Background(), "SELECT COUNT(1) FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT schema_name()) AND table_catalog = (SELECT db_name())").Scan(&exists); err != nil {
t.Fatal(err)
}
if exists != 1 {
Expand All @@ -159,7 +159,7 @@ func TestErrorParsing(t *testing.T) {
}

addr := msConnectionString(ip, port)
p := &MSSQL{}
p := &SQLServer{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -189,14 +189,14 @@ func TestLockWorks(t *testing.T) {
}

addr := fmt.Sprintf("sqlserver://sa:%v@%v:%v?master", saPassword, ip, port)
p := &MSSQL{}
p := &SQLServer{}
d, err := p.Open(addr)
if err != nil {
t.Fatalf("%v", err)
}
dt.Test(t, d, []byte("SELECT 1"))

ms := d.(*MSSQL)
ms := d.(*SQLServer)

err = ms.Lock()
if err != nil {
Expand Down
7 changes: 0 additions & 7 deletions internal/cli/build_mssql.go

This file was deleted.

7 changes: 7 additions & 0 deletions internal/cli/build_sqlserver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// +build sqlserver

package cli

import (
_ "github.com/golang-migrate/migrate/v4/database/sqlserver"
)

0 comments on commit 8437fe6

Please sign in to comment.