Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Snowflake connector #623

Merged
merged 15 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/dbml-cli/src/cli/connector.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import logger from '../helpers/logger';

export default async function connectionHandler (program) {
try {
const { connection, format } = getConnectionOpt(program.args);
const { connection, databaseType } = getConnectionOpt(program.args);
const opts = program.opts();
const schemaJson = await connector.fetchSchemaJson(connection, format);
const schemaJson = await connector.fetchSchemaJson(connection, databaseType);

if (!opts.outFile && !opts.outDir) {
const res = importer.generateDbml(schemaJson);
Expand Down
12 changes: 7 additions & 5 deletions packages/dbml-cli/src/cli/index.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/* eslint-disable max-len */
import program from 'commander';
import importHandler from './import';
import exportHandler from './export';
Expand Down Expand Up @@ -56,14 +57,15 @@ function db2dbml (args) {
// - postgres: postgresql://user:password@localhost:5432/dbname
// - mssql: 'Server=localhost,1433;Database=master;User Id=sa;Password=your_password;Encrypt=true;TrustServerCertificate
const description = `
<format> your database format (postgres, mysql, mssql)
<database-type> your database format (postgres, mysql, mssql, snowflake)
<connection-string> your database connection string:
- postgres: postgresql://user:password@localhost:5432/dbname
- mysql: mysql://user:password@localhost:3306/dbname
- mssql: 'Server=localhost,1433;Database=master;User Id=sa;Password=your_password;Encrypt=true;TrustServerCertificate=true;'
- postgres: 'postgresql://user:password@localhost:5432/dbname?schemas=schema1,schema2,schema3'
- mysql: 'mysql://user:password@localhost:3306/dbname'
- mssql: 'Server=localhost,1433;Database=master;User Id=sa;Password=your_password;Encrypt=true;TrustServerCertificate=true;Schemas=schema1,schema2,schema3;'
- snowflake: 'SERVER=<account_identifier>.<region>;UID=<your_username>;PWD=<your_password>;DATABASE=<your_database>;WAREHOUSE=<your_warehouse>;ROLE=<your_role>;SCHEMAS=schema1,schema2,schema3;'
`;
program
.usage('<format> <connection-string> [options]')
.usage('<database-type> <connection-string> [options]')
.description(description)
.option('-o, --out-file <pathspec>', 'compile all input files into a single files');

Expand Down
6 changes: 3 additions & 3 deletions packages/dbml-cli/src/cli/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ function getFormatOpt (opts) {
}

function getConnectionOpt (args) {
const supportedDatabases = ['postgres', 'mysql', 'mssql'];
const supportedDatabases = ['postgres', 'mysql', 'mssql', 'snowflake'];
const defaultConnectionOpt = {
connection: args[0],
format: 'unknown',
databaseType: 'unknown',
};

return reduce(args, (connectionOpt, arg) => {
if (supportedDatabases.includes(arg)) connectionOpt.format = arg;
if (supportedDatabases.includes(arg)) connectionOpt.databaseType = arg;
// Check if the arg is a connection string using regex
const connectionStringRegex = /^.*[:;]/;
if (connectionStringRegex.test(arg)) {
Expand Down
3 changes: 2 additions & 1 deletion packages/dbml-connector/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
"dependencies": {
"mssql": "^11.0.1",
"mysql2": "^3.11.0",
"pg": "^8.12.0"
"pg": "^8.12.0",
"snowflake-sdk": "^1.12.0"
},
"engines": {
"node": ">=18"
Expand Down
9 changes: 6 additions & 3 deletions packages/dbml-connector/src/connectors/connector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@ import { DatabaseSchema } from './types';
import { fetchSchemaJson as fetchPostgresSchemaJson } from './postgresConnector';
import { fetchSchemaJson as fetchMssqlSchemaJson } from './mssqlConnector';
import { fetchSchemaJson as fetchMysqlSchemaJson } from './mysqlConnector';
import { fetchSchemaJson as fetchSnowflakeSchemaJson } from './snowflakeConnector';

const fetchSchemaJson = async (connection: string, format: string): Promise<DatabaseSchema> => {
switch (format) {
const fetchSchemaJson = async (connection: string, databaseType: string): Promise<DatabaseSchema> => {
switch (databaseType) {
case 'postgres':
return fetchPostgresSchemaJson(connection);
case 'mssql':
return fetchMssqlSchemaJson(connection);
case 'mysql':
return fetchMysqlSchemaJson(connection);
case 'snowflake':
return fetchSnowflakeSchemaJson(connection);
default:
throw new Error(`Unsupported connection format: ${format}`);
throw new Error(`Unsupported database type: ${databaseType}`);
}
};

Expand Down
150 changes: 76 additions & 74 deletions packages/dbml-connector/src/connectors/mssqlConnector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
DatabaseSchema,
DefaultInfo,
} from './types';
import { buildSchemaQuery, parseConnectionString } from '../utils/parseSchema';

const MSSQL_DATE_TYPES = [
'date',
Expand Down Expand Up @@ -132,83 +133,79 @@ const generateField = (row: Record<string, any>): Field => {
};
};

const generateTablesFieldsAndEnums = async (client: sql.ConnectionPool): Promise<{
const generateTablesFieldsAndEnums = async (client: sql.ConnectionPool, schemas: string[]): Promise<{
tables: Table[],
fields: FieldsDictionary,
enums: Enum[],
}> => {
const fields: FieldsDictionary = {};
const enums: Enum[] = [];
const tablesAndFieldsSql = `
WITH tables_and_fields AS (
SELECT
s.name AS table_schema,
t.name AS table_name,
c.name AS column_name,
ty.name AS data_type,
c.max_length AS character_maximum_length,
c.precision AS numeric_precision,
c.scale AS numeric_scale,
c.is_identity AS identity_increment,
CASE
WHEN c.is_nullable = 1 THEN 'YES'
ELSE 'NO'
END AS is_nullable,
CASE
WHEN c.default_object_id = 0 THEN NULL
ELSE OBJECT_DEFINITION(c.default_object_id)
END AS column_default,
-- Fetching table comments
p.value AS table_comment,
ep.value AS column_comment
FROM
sys.tables t
JOIN
sys.schemas s ON t.schema_id = s.schema_id
JOIN
sys.columns c ON t.object_id = c.object_id
JOIN
sys.types ty ON c.user_type_id = ty.user_type_id
LEFT JOIN
sys.extended_properties p ON p.major_id = t.object_id
AND p.name = 'MS_Description'
AND p.minor_id = 0 -- Ensure minor_id is 0 for table comments
LEFT JOIN
sys.extended_properties ep ON ep.major_id = c.object_id
AND ep.minor_id = c.column_id
AND ep.name = 'MS_Description'
WHERE
t.type = 'U' -- User-defined tables
)
WITH tables_and_fields AS (
SELECT
tf.table_schema,
tf.table_name,
tf.column_name,
tf.data_type,
tf.character_maximum_length,
tf.numeric_precision,
tf.numeric_scale,
tf.identity_increment,
tf.is_nullable,
tf.column_default,
tf.table_comment,
tf.column_comment,
cc.name AS check_constraint_name, -- Adding CHECK constraint name
cc.definition AS check_constraint_definition, -- Adding CHECK constraint definition
CASE
WHEN tf.column_default LIKE '((%))' THEN 'number'
WHEN tf.column_default LIKE '(''%'')' THEN 'string'
ELSE 'expression'
END AS default_type
s.name AS table_schema,
t.name AS table_name,
c.name AS column_name,
ty.name AS data_type,
c.max_length AS character_maximum_length,
c.precision AS numeric_precision,
c.scale AS numeric_scale,
c.is_identity AS identity_increment,
CASE
WHEN c.is_nullable = 1 THEN 'YES'
ELSE 'NO'
END AS is_nullable,
CASE
WHEN c.default_object_id = 0 THEN NULL
ELSE OBJECT_DEFINITION(c.default_object_id)
END AS column_default,
-- Fetching table comments
p.value AS table_comment,
ep.value AS column_comment
FROM
tables_and_fields AS tf
LEFT JOIN
sys.check_constraints cc ON cc.parent_object_id = OBJECT_ID(tf.table_schema + '.' + tf.table_name)
AND cc.definition LIKE '%' + tf.column_name + '%' -- Ensure the constraint references the column
ORDER BY
tf.table_schema,
tf.table_name,
tf.column_name;
sys.tables t
JOIN sys.schemas s ON t.schema_id = s.schema_id
JOIN sys.columns c ON t.object_id = c.object_id
JOIN sys.types ty ON c.user_type_id = ty.user_type_id
LEFT JOIN sys.extended_properties p ON p.major_id = t.object_id
AND p.name = 'MS_Description'
AND p.minor_id = 0 -- Ensure minor_id is 0 for table comments
LEFT JOIN sys.extended_properties ep ON ep.major_id = c.object_id
AND ep.minor_id = c.column_id
AND ep.name = 'MS_Description'
WHERE
t.type = 'U' -- User-defined tables
)
SELECT
tf.table_schema,
tf.table_name,
tf.column_name,
tf.data_type,
tf.character_maximum_length,
tf.numeric_precision,
tf.numeric_scale,
tf.identity_increment,
tf.is_nullable,
tf.column_default,
tf.table_comment,
tf.column_comment,
cc.name AS check_constraint_name, -- Adding CHECK constraint name
cc.definition AS check_constraint_definition, -- Adding CHECK constraint definition
CASE
WHEN tf.column_default LIKE '((%))' THEN 'number'
WHEN tf.column_default LIKE '(''%'')' THEN 'string'
ELSE 'expression'
END AS default_type
FROM
tables_and_fields AS tf
LEFT JOIN sys.check_constraints cc
ON cc.parent_object_id = OBJECT_ID(tf.table_schema + '.' + tf.table_name)
AND cc.definition LIKE '%' + tf.column_name + '%' -- Ensure the constraint references the column
${buildSchemaQuery('tf.table_schema', schemas, 'WHERE')}
ORDER BY
tf.table_schema,
tf.table_name,
tf.column_name;
`;

const tablesAndFieldsResult = await client.query(tablesAndFieldsSql);
Expand Down Expand Up @@ -259,7 +256,7 @@ const generateTablesFieldsAndEnums = async (client: sql.ConnectionPool): Promise
};
};

const generateRefs = async (client: sql.ConnectionPool): Promise<Ref[]> => {
const generateRefs = async (client: sql.ConnectionPool, schemas: string[]): Promise<Ref[]> => {
const refs: Ref[] = [];

const refsListSql = `
Expand Down Expand Up @@ -290,6 +287,7 @@ const generateRefs = async (client: sql.ConnectionPool): Promise<Ref[]> => {
JOIN sys.tables AS t2 ON fk.referenced_object_id = t2.object_id
JOIN sys.schemas AS s2 ON t2.schema_id = s2.schema_id
WHERE s.name NOT IN ('sys', 'information_schema')
${buildSchemaQuery('s.name', schemas)}
ORDER BY
s.name,
t.name;
Expand Down Expand Up @@ -334,7 +332,7 @@ const generateRefs = async (client: sql.ConnectionPool): Promise<Ref[]> => {
return refs;
};

const generateIndexes = async (client: sql.ConnectionPool) => {
const generateIndexes = async (client: sql.ConnectionPool, schemas: string[]) => {
const indexListSql = `
WITH user_tables AS (
SELECT
Expand All @@ -352,6 +350,7 @@ const generateIndexes = async (client: sql.ConnectionPool) => {
),
index_info AS (
SELECT
SCHEMA_NAME(t.schema_id) AS table_schema, -- Add schema information
OBJECT_NAME(i.object_id) AS table_name,
i.name AS index_name,
i.is_unique,
Expand Down Expand Up @@ -399,8 +398,10 @@ const generateIndexes = async (client: sql.ConnectionPool) => {
user_tables ut
LEFT JOIN
index_info ii ON ut.TABLE_NAME = ii.table_name
AND ut.TABLE_SCHEMA = ii.table_schema
WHERE
ii.columns IS NOT NULL
${buildSchemaQuery('ut.TABLE_SCHEMA', schemas)}
ORDER BY
ut.TABLE_NAME,
ii.constraint_type,
Expand Down Expand Up @@ -491,11 +492,12 @@ const generateIndexes = async (client: sql.ConnectionPool) => {
};

const fetchSchemaJson = async (connection: string): Promise<DatabaseSchema> => {
const client = await getValidatedClient(connection);
const { connectionString, schemas } = parseConnectionString(connection, 'odbc');
const client = await getValidatedClient(connectionString);

const tablesFieldsAndEnumsRes = generateTablesFieldsAndEnums(client);
const indexesRes = generateIndexes(client);
const refsRes = generateRefs(client);
const tablesFieldsAndEnumsRes = generateTablesFieldsAndEnums(client, schemas);
const indexesRes = generateIndexes(client, schemas);
const refsRes = generateRefs(client, schemas);

const res = await Promise.all([
tablesFieldsAndEnumsRes,
Expand Down
Loading
Loading