diff --git a/cmd/mangosql/actions/codegen/codegen.go b/cmd/mangosql/actions/codegen/codegen.go index 9c2e309..c696504 100644 --- a/cmd/mangosql/actions/codegen/codegen.go +++ b/cmd/mangosql/actions/codegen/codegen.go @@ -64,6 +64,7 @@ func generate(opts generateOptions) error { // parse schema schema, err := internal.ParseSchema(sql) if err != nil { + fmt.Printf("schema parsing error: %+v\n", err) return err } diff --git a/cmd/mangosql/actions/diagram/diagram.go b/cmd/mangosql/actions/diagram/diagram.go index d604e63..e8100c9 100644 --- a/cmd/mangosql/actions/diagram/diagram.go +++ b/cmd/mangosql/actions/diagram/diagram.go @@ -103,6 +103,7 @@ func diagram(opts diagramOptions) error { // parse schema schema, err := internal.ParseSchema(sql) if err != nil { + fmt.Printf("schema parsing error: %+v\n", err) return err } diff --git a/internal/preparser.go b/internal/preparser.go index ddf5fe8..d75bd9d 100644 --- a/internal/preparser.go +++ b/internal/preparser.go @@ -98,7 +98,7 @@ func removeComments(sql string) string { } func replaceMysqlDateTypes(sql string) string { - regCond := regexp.MustCompile(`(?i) (datetime|date|timestamp)(\(\d*\))?`) + regCond := regexp.MustCompile(`(?i)\s(datetime|date|timestamp)(\(\d*\))?`) matches := regCond.FindAllStringSubmatchIndex(sql, -1) slices.Reverse(matches) for _, match := range matches { @@ -111,7 +111,7 @@ func replaceMysqlDateTypes(sql string) string { func replaceMysqlIntTypes(sql string) string { sql = strings.ReplaceAll(sql, " unsigned ", " ") - regCond := regexp.MustCompile(`(?i) (integer|smallint|tinyint|bigint|int)(\(\d*\))?`) + regCond := regexp.MustCompile(`(?i)\s(integer|smallint|tinyint|bigint|int)(\(\d*\))?`) matches := regCond.FindAllStringSubmatchIndex(sql, -1) slices.Reverse(matches) for _, match := range matches { @@ -122,7 +122,7 @@ func replaceMysqlIntTypes(sql string) string { } func replaceMysqlFloatTypes(sql string) string { - regCond := regexp.MustCompile(`(?i) (double|float)(\(\d*\))?`) + regCond := regexp.MustCompile(`(?i)\s(double precision|double|float)(\(\d*\))?`) matches := regCond.FindAllStringSubmatchIndex(sql, -1) slices.Reverse(matches) for _, match := range matches { @@ -133,18 +133,25 @@ func replaceMysqlFloatTypes(sql string) string { } func replaceMysqlTextTypes(sql string) string { - regCond := regexp.MustCompile(`(?i) (enum|mediumtext|longtext|tinytext|nvarchar|character varying|char|set)(\(.*?\))?`) + regCond := regexp.MustCompile(`(?i)\s(mediumtext|longtext|tinytext|nvarchar|character varying|character|char)(\(.*?\))?`) matches := regCond.FindAllStringSubmatchIndex(sql, -1) slices.Reverse(matches) for _, match := range matches { sql = sql[:match[0]] + " text" + sql[match[1]:] } + regEnums := regexp.MustCompile(`(?i)\s(enum|set)(\(.*?\))`) + matchesEnums := regEnums.FindAllStringSubmatchIndex(sql, -1) + slices.Reverse(matchesEnums) + for _, match := range matchesEnums { + sql = sql[:match[0]] + " text" + sql[match[1]:] + } + return sql } func replaceMysqlDataTypes(sql string) string { - regCond := regexp.MustCompile(`(?i) (binary|longblob|mediumblob|tinyblob|blob|tsvector)(\(\d*\))?`) + regCond := regexp.MustCompile(`(?i)\s(binary|longblob|mediumblob|tinyblob|blob|tsvector)(\(\d*\))?`) matches := regCond.FindAllStringSubmatchIndex(sql, -1) slices.Reverse(matches) for _, match := range matches { @@ -155,7 +162,7 @@ func replaceMysqlDataTypes(sql string) string { } func replaceMysqlUpdate(sql string) string { - regCond := regexp.MustCompile(`(?i) ON (DELETE|UPDATE) SET .*`) + regCond := regexp.MustCompile(`(?i)\sON\s(DELETE|UPDATE)\sSET\s\w*`) matches := regCond.FindAllStringSubmatchIndex(sql, -1) slices.Reverse(matches) for _, match := range matches { @@ -166,13 +173,20 @@ func replaceMysqlUpdate(sql string) string { } func replaceMysqlChartset(sql string) string { - regCond := regexp.MustCompile(`(?i) ((CHARACTER SET|COLLATE|DEFAULT CHARSET|CHARTSET|ENGINE|AUTO_INCREMENT|COMMENT)[ =]('.*?'|.*))[,; ]`) + regCond := regexp.MustCompile(`(?i)\s((CHARACTER SET|COLLATE|DEFAULT CHARSET|CHARTSET|ENGINE|AUTO_INCREMENT)[ =]('.*'|.*))[,; ]`) matches := regCond.FindAllStringSubmatchIndex(sql, -1) slices.Reverse(matches) for _, match := range matches { sql = sql[:match[2]] + " " + sql[match[3]:] } + regComment := regexp.MustCompile(`(?i)\sCOMMENT\s'.*'`) + matchesComment := regComment.FindAllStringSubmatchIndex(sql, -1) + slices.Reverse(matchesComment) + for _, match := range matchesComment { + sql = sql[:match[0]] + " " + sql[match[1]:] + } + return sql } @@ -250,11 +264,11 @@ func filterValidOperations(sql string) string { for _, res := range res { txt := sql[res[0]:res[1]] - if strings.Contains(txt, "OWNER TO") { + if strings.Contains(strings.ToLower(txt), "owner to") { continue } - if strings.Contains(txt, "to_tsvector") { + if strings.Contains(strings.ToLower(txt), "to_tsvector") { continue } diff --git a/tests/parser/parser_test.go b/tests/parser/parser_test.go index 9ccd9f2..ba0b2f3 100644 --- a/tests/parser/parser_test.go +++ b/tests/parser/parser_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/kefniark/mango-sql/internal" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMysql(t *testing.T) { @@ -20,12 +20,12 @@ func TestMysql(t *testing.T) { t.Run(entry.Name(), func(t *testing.T) { data, err := os.ReadFile(path.Join(folder, entry.Name(), "schema.sql")) if err != nil { - assert.NoError(t, err) + require.NoError(t, err) } _, err = internal.ParseSchema(string(data)) if err != nil { - assert.NoError(t, err) + require.NoError(t, err) } }) } @@ -42,12 +42,12 @@ func TestPostgres(t *testing.T) { t.Run(entry.Name(), func(t *testing.T) { data, err := os.ReadFile(path.Join(folder, entry.Name(), "schema.sql")) if err != nil { - assert.NoError(t, err) + require.NoError(t, err) } _, err = internal.ParseSchema(string(data)) if err != nil { - assert.NoError(t, err) + require.NoError(t, err) } }) }