diff --git a/pkg/db/mysql.go b/pkg/db/mysql.go index 4bfb37f..3327013 100644 --- a/pkg/db/mysql.go +++ b/pkg/db/mysql.go @@ -3,6 +3,7 @@ package db import ( "context" "database/sql" + "errors" "github.com/DanielLiu1123/gencoder/pkg/model" "slices" "sort" @@ -22,6 +23,9 @@ where table_schema = ? tableRow := db.QueryRowContext(ctx, tableSql, schema, table) var t model.Table if err := tableRow.Scan(&t.Schema, &t.Name, &t.Comment); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } return nil, err } @@ -45,7 +49,7 @@ order by ordinal_position; } defer columnRows.Close() - var columns []*model.Column + columns := make([]*model.Column, 0) for columnRows.Next() { var col model.Column if err := columnRows.Scan(&col.Ordinal, &col.Name, &col.Type, &col.IsNullable, &col.DefaultValue, &col.IsPrimaryKey, &col.Comment); err != nil { @@ -103,7 +107,7 @@ order by index_name, seq_in_index; }) } - var indexes []*model.Index + indexes := make([]*model.Index, 0) for _, index := range indexMap { indexes = append(indexes, index) } diff --git a/pkg/db/mysql_test.go b/pkg/db/mysql_test.go index 8c805a4..10d65d5 100644 --- a/pkg/db/mysql_test.go +++ b/pkg/db/mysql_test.go @@ -2,8 +2,8 @@ package db import ( "context" - "database/sql" "github.com/stretchr/testify/assert" + "github.com/xo/dburl" "os/exec" "testing" "time" @@ -56,17 +56,10 @@ func TestGenMySQLTable(t *testing.T) { // Wait for MySQL to initialize time.Sleep(10 * time.Second) - dsn := "root:root@tcp(127.0.0.1:3306)/testdb" - db, err := sql.Open("mysql", dsn) + db, err := dburl.Open("mysql://root:root@localhost:3306/testdb") if err != nil { - t.Fatalf("Failed to connect to the database: %s", err) + t.Fatalf("Failed to open database connection: %s", err) } - defer func(db *sql.DB) { - err := db.Close() - if err != nil { - t.Fatalf("Failed to close the database: %s", err) - } - }(db) _, err = db.Exec(`CREATE TABLE testdb.user ( id INT AUTO_INCREMENT PRIMARY KEY, diff --git a/pkg/db/postgres.go b/pkg/db/postgres.go index 4b86cf1..5d16843 100644 --- a/pkg/db/postgres.go +++ b/pkg/db/postgres.go @@ -3,6 +3,7 @@ package db import ( "context" "database/sql" + "errors" "github.com/DanielLiu1123/gencoder/pkg/model" "slices" "sort" @@ -23,6 +24,9 @@ WHERE table_schema = $1 tableRow := db.QueryRowContext(ctx, tableSql, schema, name) var t model.Table if err := tableRow.Scan(&t.Schema, &t.Name, &t.Comment); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } return nil, err } @@ -53,7 +57,7 @@ ORDER BY a.attnum; } defer columnRows.Close() - var columns []*model.Column + columns := make([]*model.Column, 0) for columnRows.Next() { var col model.Column if err := columnRows.Scan(&col.Ordinal, &col.Name, &col.Type, &col.IsNullable, &col.DefaultValue, &col.IsPrimaryKey, &col.Comment); err != nil { @@ -118,7 +122,7 @@ ORDER BY ic.relname, indkey_col.ordinality; }) } - var indexes []*model.Index + indexes := make([]*model.Index, 0) for _, index := range indexMap { indexes = append(indexes, index) } diff --git a/pkg/util/util.go b/pkg/util/util.go index a4c8bda..d14b284 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -143,6 +143,12 @@ func collectRenderContextsForDBConfig(dbCfg *model.DatabaseConfig) []*model.Rend log.Fatal(err) } + // table not found + if table == nil { + log.Printf("table %s.%s not found, skipping", schema, tbCfg.Name) + continue + } + ctx := createRenderContext(dbCfg, tbCfg, table) contexts = append(contexts, ctx)