diff --git a/pkg/analyzer/analyzer.go b/pkg/analyzer/analyzer.go index 5309635..1f063b4 100644 --- a/pkg/analyzer/analyzer.go +++ b/pkg/analyzer/analyzer.go @@ -14,14 +14,15 @@ import ( ) const ( - TimeWeekdayFlag = "time-weekday" - TimeMonthFlag = "time-month" - TimeLayoutFlag = "time-layout" - CryptoHashFlag = "crypto-hash" - HTTPMethodFlag = "http-method" - HTTPStatusCodeFlag = "http-status-code" - RPCDefaultPathFlag = "rpc-default-path" - OSDevNullFlag = "os-dev-null" + TimeWeekdayFlag = "time-weekday" + TimeMonthFlag = "time-month" + TimeLayoutFlag = "time-layout" + CryptoHashFlag = "crypto-hash" + HTTPMethodFlag = "http-method" + HTTPStatusCodeFlag = "http-status-code" + RPCDefaultPathFlag = "rpc-default-path" + OSDevNullFlag = "os-dev-null" + SQLIsolationLevelFlag = "sql-isolation-level" ) // New returns new usestdlibvars analyzer. @@ -45,6 +46,7 @@ func flags() flag.FlagSet { flags.Bool(CryptoHashFlag, false, "suggest the use of crypto.Hash") flags.Bool(RPCDefaultPathFlag, false, "suggest the use of rpc.DefaultXXPath") flags.Bool(OSDevNullFlag, false, "suggest the use of os.DevNull") + flags.Bool(SQLIsolationLevelFlag, false, "suggest the use of sql.LevelXX") return *flags } @@ -99,6 +101,10 @@ func run(pass *analysis.Pass) (interface{}, error) { checkOSDevNull(pass, n) } + if lookupFlag(pass, SQLIsolationLevelFlag) { + checkSQLIsolationLevel(pass, n) + } + case *ast.CompositeLit: typ, ok := n.Type.(*ast.SelectorExpr) if !ok { @@ -403,6 +409,14 @@ func checkOSDevNull(pass *analysis.Pass, basicLit *ast.BasicLit) { } } +func checkSQLIsolationLevel(pass *analysis.Pass, basicLit *ast.BasicLit) { + currentVal := getBasicLitValue(basicLit) + + if newVal, ok := mapping.SQLIsolationLevel[currentVal]; ok { + report(pass, basicLit.Pos(), currentVal, newVal) + } +} + // getBasicLitFromArgs gets the *ast.BasicLit of a function argument. // // Arguments: diff --git a/pkg/analyzer/analyzer_test.go b/pkg/analyzer/analyzer_test.go index 01f8fa2..9c7dd08 100644 --- a/pkg/analyzer/analyzer_test.go +++ b/pkg/analyzer/analyzer_test.go @@ -15,6 +15,7 @@ func TestUseStdlibVars(t *testing.T) { "a/rpc", "a/time", "a/os", + "a/sql", } a := analyzer.New() @@ -37,6 +38,9 @@ func TestUseStdlibVars(t *testing.T) { if err := a.Flags.Set(analyzer.OSDevNullFlag, "true"); err != nil { t.Error(err) } + if err := a.Flags.Set(analyzer.SQLIsolationLevelFlag, "true"); err != nil { + t.Error(err) + } analysistest.Run(t, analysistest.TestData(), a, pkgs...) } diff --git a/pkg/analyzer/internal/gen.go b/pkg/analyzer/internal/gen.go index 743f03a..7f343bc 100644 --- a/pkg/analyzer/internal/gen.go +++ b/pkg/analyzer/internal/gen.go @@ -89,6 +89,12 @@ func main() { templateName: "test-template.go.tmpl", fileName: "pkg/analyzer/testdata/src/a/os/devnull.go", }, + { + mapping: mapping.SQLIsolationLevel, + packageName: "sql_test", + templateName: "test-template.go.tmpl", + fileName: "pkg/analyzer/testdata/src/a/sql/isolationlevel.go", + }, } for _, operation := range operations { diff --git a/pkg/analyzer/internal/mapping/mapping.go b/pkg/analyzer/internal/mapping/mapping.go index af89d50..c0049cf 100644 --- a/pkg/analyzer/internal/mapping/mapping.go +++ b/pkg/analyzer/internal/mapping/mapping.go @@ -2,6 +2,7 @@ package mapping import ( "crypto" + "database/sql" "net/http" "net/rpc" "os" @@ -164,3 +165,14 @@ var TimeLayout = map[string]string{ var OSDevNull = map[string]string{ os.DevNull: "os.DevNull", } + +var SQLIsolationLevel = map[string]string{ + // sql.LevelDefault.String(): "sql.LevelDefault.String()", + sql.LevelReadUncommitted.String(): "sql.LevelReadUncommitted.String()", + sql.LevelReadCommitted.String(): "sql.LevelReadCommitted.String()", + sql.LevelWriteCommitted.String(): "sql.LevelWriteCommitted.String()", + sql.LevelRepeatableRead.String(): "sql.LevelRepeatableRead.String()", + // sql.LevelSnapshot.String(): "sql.LevelSnapshot.String()", + // sql.LevelSerializable.String(): "sql.LevelSerializable.String()", + // sql.LevelLinearizable.String(): "sql.LevelLinearizable.String()", +} diff --git a/pkg/analyzer/testdata/src/a/sql/isolationlevel.go b/pkg/analyzer/testdata/src/a/sql/isolationlevel.go new file mode 100755 index 0000000..ccc5422 --- /dev/null +++ b/pkg/analyzer/testdata/src/a/sql/isolationlevel.go @@ -0,0 +1,49 @@ +// Code generated by usestdlibvars, DO NOT EDIT. + +package sql_test + +import "fmt" + +var ( + _ = "Read Committed" // want `"Read Committed" can be replaced by sql\.LevelReadCommitted\.String\(\)` + _ = "Read Uncommitted" // want `"Read Uncommitted" can be replaced by sql\.LevelReadUncommitted\.String\(\)` + _ = "Repeatable Read" // want `"Repeatable Read" can be replaced by sql\.LevelRepeatableRead\.String\(\)` + _ = "Write Committed" // want `"Write Committed" can be replaced by sql\.LevelWriteCommitted\.String\(\)` +) + +const ( + _ = "Read Committed" // want `"Read Committed" can be replaced by sql\.LevelReadCommitted\.String\(\)` + _ = "Read Uncommitted" // want `"Read Uncommitted" can be replaced by sql\.LevelReadUncommitted\.String\(\)` + _ = "Repeatable Read" // want `"Repeatable Read" can be replaced by sql\.LevelRepeatableRead\.String\(\)` + _ = "Write Committed" // want `"Write Committed" can be replaced by sql\.LevelWriteCommitted\.String\(\)` +) + +func _() { + _ = func(s string) string { return s }("Read Committed") // want `"Read Committed" can be replaced by sql\.LevelReadCommitted\.String\(\)` + _ = func(s string) string { return s }("text before key Read Committed") + _ = func(s string) string { return s }("Read Committed text after key") + _ = func(s string) string { return s }("Read Uncommitted") // want `"Read Uncommitted" can be replaced by sql\.LevelReadUncommitted\.String\(\)` + _ = func(s string) string { return s }("text before key Read Uncommitted") + _ = func(s string) string { return s }("Read Uncommitted text after key") + _ = func(s string) string { return s }("Repeatable Read") // want `"Repeatable Read" can be replaced by sql\.LevelRepeatableRead\.String\(\)` + _ = func(s string) string { return s }("text before key Repeatable Read") + _ = func(s string) string { return s }("Repeatable Read text after key") + _ = func(s string) string { return s }("Write Committed") // want `"Write Committed" can be replaced by sql\.LevelWriteCommitted\.String\(\)` + _ = func(s string) string { return s }("text before key Write Committed") + _ = func(s string) string { return s }("Write Committed text after key") +} + +func _() { + _ = fmt.Sprint("Read Committed") // want `"Read Committed" can be replaced by sql\.LevelReadCommitted\.String\(\)` + _ = fmt.Sprint("text before key Read Committed") + _ = fmt.Sprint("Read Committed text after key") + _ = fmt.Sprint("Read Uncommitted") // want `"Read Uncommitted" can be replaced by sql\.LevelReadUncommitted\.String\(\)` + _ = fmt.Sprint("text before key Read Uncommitted") + _ = fmt.Sprint("Read Uncommitted text after key") + _ = fmt.Sprint("Repeatable Read") // want `"Repeatable Read" can be replaced by sql\.LevelRepeatableRead\.String\(\)` + _ = fmt.Sprint("text before key Repeatable Read") + _ = fmt.Sprint("Repeatable Read text after key") + _ = fmt.Sprint("Write Committed") // want `"Write Committed" can be replaced by sql\.LevelWriteCommitted\.String\(\)` + _ = fmt.Sprint("text before key Write Committed") + _ = fmt.Sprint("Write Committed text after key") +}