From 0b118d674dc67155403b08390ab4e6d46aa9ca87 Mon Sep 17 00:00:00 2001
From: Sasha Melentyev <sasha@melentyev.io>
Date: Thu, 1 Sep 2022 14:19:47 +0300
Subject: [PATCH] feat: add sql.IsolationLevel

---
 pkg/analyzer/analyzer.go                      | 30 +++++++++---
 pkg/analyzer/analyzer_test.go                 |  4 ++
 pkg/analyzer/internal/gen.go                  |  6 +++
 pkg/analyzer/internal/mapping/mapping.go      | 12 +++++
 .../testdata/src/a/sql/isolationlevel.go      | 49 +++++++++++++++++++
 5 files changed, 93 insertions(+), 8 deletions(-)
 create mode 100755 pkg/analyzer/testdata/src/a/sql/isolationlevel.go

diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go
index 5309635..1f063b4 100644
--- a/pkg/analyzer/analyzer.go
+++ b/pkg/analyzer/analyzer.go
@@ -14,14 +14,15 @@ import (
 )
 
 const (
-	TimeWeekdayFlag    = "time-weekday"
-	TimeMonthFlag      = "time-month"
-	TimeLayoutFlag     = "time-layout"
-	CryptoHashFlag     = "crypto-hash"
-	HTTPMethodFlag     = "http-method"
-	HTTPStatusCodeFlag = "http-status-code"
-	RPCDefaultPathFlag = "rpc-default-path"
-	OSDevNullFlag      = "os-dev-null"
+	TimeWeekdayFlag       = "time-weekday"
+	TimeMonthFlag         = "time-month"
+	TimeLayoutFlag        = "time-layout"
+	CryptoHashFlag        = "crypto-hash"
+	HTTPMethodFlag        = "http-method"
+	HTTPStatusCodeFlag    = "http-status-code"
+	RPCDefaultPathFlag    = "rpc-default-path"
+	OSDevNullFlag         = "os-dev-null"
+	SQLIsolationLevelFlag = "sql-isolation-level"
 )
 
 // New returns new usestdlibvars analyzer.
@@ -45,6 +46,7 @@ func flags() flag.FlagSet {
 	flags.Bool(CryptoHashFlag, false, "suggest the use of crypto.Hash")
 	flags.Bool(RPCDefaultPathFlag, false, "suggest the use of rpc.DefaultXXPath")
 	flags.Bool(OSDevNullFlag, false, "suggest the use of os.DevNull")
+	flags.Bool(SQLIsolationLevelFlag, false, "suggest the use of sql.LevelXX")
 	return *flags
 }
 
@@ -99,6 +101,10 @@ func run(pass *analysis.Pass) (interface{}, error) {
 				checkOSDevNull(pass, n)
 			}
 
+			if lookupFlag(pass, SQLIsolationLevelFlag) {
+				checkSQLIsolationLevel(pass, n)
+			}
+
 		case *ast.CompositeLit:
 			typ, ok := n.Type.(*ast.SelectorExpr)
 			if !ok {
@@ -403,6 +409,14 @@ func checkOSDevNull(pass *analysis.Pass, basicLit *ast.BasicLit) {
 	}
 }
 
+func checkSQLIsolationLevel(pass *analysis.Pass, basicLit *ast.BasicLit) {
+	currentVal := getBasicLitValue(basicLit)
+
+	if newVal, ok := mapping.SQLIsolationLevel[currentVal]; ok {
+		report(pass, basicLit.Pos(), currentVal, newVal)
+	}
+}
+
 // getBasicLitFromArgs gets the *ast.BasicLit of a function argument.
 //
 // Arguments:
diff --git a/pkg/analyzer/analyzer_test.go b/pkg/analyzer/analyzer_test.go
index 01f8fa2..9c7dd08 100644
--- a/pkg/analyzer/analyzer_test.go
+++ b/pkg/analyzer/analyzer_test.go
@@ -15,6 +15,7 @@ func TestUseStdlibVars(t *testing.T) {
 		"a/rpc",
 		"a/time",
 		"a/os",
+		"a/sql",
 	}
 
 	a := analyzer.New()
@@ -37,6 +38,9 @@ func TestUseStdlibVars(t *testing.T) {
 	if err := a.Flags.Set(analyzer.OSDevNullFlag, "true"); err != nil {
 		t.Error(err)
 	}
+	if err := a.Flags.Set(analyzer.SQLIsolationLevelFlag, "true"); err != nil {
+		t.Error(err)
+	}
 
 	analysistest.Run(t, analysistest.TestData(), a, pkgs...)
 }
diff --git a/pkg/analyzer/internal/gen.go b/pkg/analyzer/internal/gen.go
index 743f03a..7f343bc 100644
--- a/pkg/analyzer/internal/gen.go
+++ b/pkg/analyzer/internal/gen.go
@@ -89,6 +89,12 @@ func main() {
 			templateName: "test-template.go.tmpl",
 			fileName:     "pkg/analyzer/testdata/src/a/os/devnull.go",
 		},
+		{
+			mapping:      mapping.SQLIsolationLevel,
+			packageName:  "sql_test",
+			templateName: "test-template.go.tmpl",
+			fileName:     "pkg/analyzer/testdata/src/a/sql/isolationlevel.go",
+		},
 	}
 
 	for _, operation := range operations {
diff --git a/pkg/analyzer/internal/mapping/mapping.go b/pkg/analyzer/internal/mapping/mapping.go
index af89d50..c0049cf 100644
--- a/pkg/analyzer/internal/mapping/mapping.go
+++ b/pkg/analyzer/internal/mapping/mapping.go
@@ -2,6 +2,7 @@ package mapping
 
 import (
 	"crypto"
+	"database/sql"
 	"net/http"
 	"net/rpc"
 	"os"
@@ -164,3 +165,14 @@ var TimeLayout = map[string]string{
 var OSDevNull = map[string]string{
 	os.DevNull: "os.DevNull",
 }
+
+var SQLIsolationLevel = map[string]string{
+	// sql.LevelDefault.String():         "sql.LevelDefault.String()",
+	sql.LevelReadUncommitted.String(): "sql.LevelReadUncommitted.String()",
+	sql.LevelReadCommitted.String():   "sql.LevelReadCommitted.String()",
+	sql.LevelWriteCommitted.String():  "sql.LevelWriteCommitted.String()",
+	sql.LevelRepeatableRead.String():  "sql.LevelRepeatableRead.String()",
+	// sql.LevelSnapshot.String():        "sql.LevelSnapshot.String()",
+	// sql.LevelSerializable.String():    "sql.LevelSerializable.String()",
+	// sql.LevelLinearizable.String():    "sql.LevelLinearizable.String()",
+}
diff --git a/pkg/analyzer/testdata/src/a/sql/isolationlevel.go b/pkg/analyzer/testdata/src/a/sql/isolationlevel.go
new file mode 100755
index 0000000..ccc5422
--- /dev/null
+++ b/pkg/analyzer/testdata/src/a/sql/isolationlevel.go
@@ -0,0 +1,49 @@
+// Code generated by usestdlibvars, DO NOT EDIT.
+
+package sql_test
+
+import "fmt"
+
+var (
+	_ = "Read Committed"   // want `"Read Committed" can be replaced by sql\.LevelReadCommitted\.String\(\)`
+	_ = "Read Uncommitted" // want `"Read Uncommitted" can be replaced by sql\.LevelReadUncommitted\.String\(\)`
+	_ = "Repeatable Read"  // want `"Repeatable Read" can be replaced by sql\.LevelRepeatableRead\.String\(\)`
+	_ = "Write Committed"  // want `"Write Committed" can be replaced by sql\.LevelWriteCommitted\.String\(\)`
+)
+
+const (
+	_ = "Read Committed"   // want `"Read Committed" can be replaced by sql\.LevelReadCommitted\.String\(\)`
+	_ = "Read Uncommitted" // want `"Read Uncommitted" can be replaced by sql\.LevelReadUncommitted\.String\(\)`
+	_ = "Repeatable Read"  // want `"Repeatable Read" can be replaced by sql\.LevelRepeatableRead\.String\(\)`
+	_ = "Write Committed"  // want `"Write Committed" can be replaced by sql\.LevelWriteCommitted\.String\(\)`
+)
+
+func _() {
+	_ = func(s string) string { return s }("Read Committed") // want `"Read Committed" can be replaced by sql\.LevelReadCommitted\.String\(\)`
+	_ = func(s string) string { return s }("text before key Read Committed")
+	_ = func(s string) string { return s }("Read Committed text after key")
+	_ = func(s string) string { return s }("Read Uncommitted") // want `"Read Uncommitted" can be replaced by sql\.LevelReadUncommitted\.String\(\)`
+	_ = func(s string) string { return s }("text before key Read Uncommitted")
+	_ = func(s string) string { return s }("Read Uncommitted text after key")
+	_ = func(s string) string { return s }("Repeatable Read") // want `"Repeatable Read" can be replaced by sql\.LevelRepeatableRead\.String\(\)`
+	_ = func(s string) string { return s }("text before key Repeatable Read")
+	_ = func(s string) string { return s }("Repeatable Read text after key")
+	_ = func(s string) string { return s }("Write Committed") // want `"Write Committed" can be replaced by sql\.LevelWriteCommitted\.String\(\)`
+	_ = func(s string) string { return s }("text before key Write Committed")
+	_ = func(s string) string { return s }("Write Committed text after key")
+}
+
+func _() {
+	_ = fmt.Sprint("Read Committed") // want `"Read Committed" can be replaced by sql\.LevelReadCommitted\.String\(\)`
+	_ = fmt.Sprint("text before key Read Committed")
+	_ = fmt.Sprint("Read Committed text after key")
+	_ = fmt.Sprint("Read Uncommitted") // want `"Read Uncommitted" can be replaced by sql\.LevelReadUncommitted\.String\(\)`
+	_ = fmt.Sprint("text before key Read Uncommitted")
+	_ = fmt.Sprint("Read Uncommitted text after key")
+	_ = fmt.Sprint("Repeatable Read") // want `"Repeatable Read" can be replaced by sql\.LevelRepeatableRead\.String\(\)`
+	_ = fmt.Sprint("text before key Repeatable Read")
+	_ = fmt.Sprint("Repeatable Read text after key")
+	_ = fmt.Sprint("Write Committed") // want `"Write Committed" can be replaced by sql\.LevelWriteCommitted\.String\(\)`
+	_ = fmt.Sprint("text before key Write Committed")
+	_ = fmt.Sprint("Write Committed text after key")
+}