Skip to content

Commit

Permalink
✨ add support for AUTOINCREMENT
Browse files Browse the repository at this point in the history
  • Loading branch information
techouse committed Feb 10, 2024
1 parent 5369cf2 commit 8d3f70a
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 48 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# 2.1.10

* [FEAT] add support for AUTOINCREMENT

# 2.1.9

* [FIX] pin MySQL Connector/Python to 8.3.0
Expand Down
3 changes: 2 additions & 1 deletion mysql_to_sqlite3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utility to transfer data from MySQL to SQLite 3."""
__version__ = "2.1.9"

__version__ = "2.1.10"

from .transporter import MySQLtoSQLite
1 change: 1 addition & 0 deletions mysql_to_sqlite3/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The command line interface of MySQLtoSQLite."""

import os
import sys
import typing as t
Expand Down
17 changes: 17 additions & 0 deletions mysql_to_sqlite3/sqlite_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,20 @@ def convert_date(value: t.Any) -> date:
return date.fromisoformat(value.decode())
except ValueError as err:
raise ValueError(f"DATE field contains {err}") # pylint: disable=W0707


Integer_Types: t.Set[str] = {
"INTEGER",
"INTEGER UNSIGNED",
"INT",
"INT UNSIGNED",
"BIGINT",
"BIGINT UNSIGNED",
"MEDIUMINT",
"MEDIUMINT UNSIGNED",
"SMALLINT",
"SMALLINT UNSIGNED",
"TINYINT",
"TINYINT UNSIGNED",
"NUMERIC",
}
85 changes: 61 additions & 24 deletions mysql_to_sqlite3/transporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from mysql_to_sqlite3.mysql_utils import CHARSET_INTRODUCERS
from mysql_to_sqlite3.sqlite_utils import (
CollatingSequences,
Integer_Types,
adapt_decimal,
adapt_timedelta,
convert_date,
Expand Down Expand Up @@ -384,24 +385,42 @@ def _build_create_table_sql(self, table_name: str) -> str:
column_type=row["Type"], # type: ignore[arg-type]
sqlite_json1_extension_enabled=self._sqlite_json1_extension_enabled,
)
sql += '\n\t"{name}" {type} {notnull} {default} {collation},'.format(
name=row["Field"].decode() if isinstance(row["Field"], bytes) else row["Field"],
type=column_type,
notnull="NULL" if row["Null"] == "YES" else "NOT NULL",
default=self._translate_default_from_mysql_to_sqlite(row["Default"], column_type, row["Extra"]),
collation=self._data_type_collation_sequence(self._collation, column_type),
)
if row["Key"] == "PRI" and row["Extra"] == "auto_increment":
if column_type in Integer_Types:
sql += '\n\t"{name}" INTEGER PRIMARY KEY AUTOINCREMENT,'.format(
name=row["Field"].decode() if isinstance(row["Field"], bytes) else row["Field"],
)
else:
self._logger.warning(
'Primary key "%s" in table "%s" is not an INTEGER type! Skipping.',
row["Field"],
table_name,
)
else:
sql += '\n\t"{name}" {type} {notnull} {default} {collation},'.format(
name=row["Field"].decode() if isinstance(row["Field"], bytes) else row["Field"],
type=column_type,
notnull="NULL" if row["Null"] == "YES" else "NOT NULL",
default=self._translate_default_from_mysql_to_sqlite(row["Default"], column_type, row["Extra"]),
collation=self._data_type_collation_sequence(self._collation, column_type),
)

self._mysql_cur_dict.execute(
"""
SELECT INDEX_NAME AS `name`,
IF (NON_UNIQUE = 0 AND INDEX_NAME = 'PRIMARY', 1, 0) AS `primary`,
IF (NON_UNIQUE = 0 AND INDEX_NAME <> 'PRIMARY', 1, 0) AS `unique`,
GROUP_CONCAT(COLUMN_NAME ORDER BY SEQ_IN_INDEX) AS `columns`
FROM information_schema.STATISTICS
WHERE TABLE_SCHEMA = %s
AND TABLE_NAME = %s
GROUP BY INDEX_NAME, NON_UNIQUE
SELECT s.INDEX_NAME AS `name`,
IF (NON_UNIQUE = 0 AND s.INDEX_NAME = 'PRIMARY', 1, 0) AS `primary`,
IF (NON_UNIQUE = 0 AND s.INDEX_NAME <> 'PRIMARY', 1, 0) AS `unique`,
IF (c.EXTRA = 'auto_increment', 1, 0) AS `auto_increment`,
GROUP_CONCAT(s.COLUMN_NAME ORDER BY SEQ_IN_INDEX) AS `columns`,
GROUP_CONCAT(c.COLUMN_TYPE ORDER BY SEQ_IN_INDEX) AS `types`
FROM information_schema.STATISTICS AS s
JOIN information_schema.COLUMNS AS c
ON s.TABLE_SCHEMA = c.TABLE_SCHEMA
AND s.TABLE_NAME = c.TABLE_NAME
AND s.COLUMN_NAME = c.COLUMN_NAME
WHERE s.TABLE_SCHEMA = %s
AND s.TABLE_NAME = %s
GROUP BY s.INDEX_NAME, s.NON_UNIQUE, c.EXTRA
""",
(self._mysql_database, table_name),
)
Expand Down Expand Up @@ -437,17 +456,33 @@ def _build_create_table_sql(self, table_name: str) -> str:
elif isinstance(index["columns"], str):
columns = index["columns"]

types: str = ""
if isinstance(index["types"], bytes):
types = index["types"].decode()
elif isinstance(index["types"], str):
types = index["types"]

if len(columns) > 0:
if index["primary"] in {1, "1"}:
primary += "\n\tPRIMARY KEY ({})".format(
", ".join(f'"{column}"' for column in columns.split(","))
)
if (index["auto_increment"] not in {1, "1"}) or any(
self._translate_type_from_mysql_to_sqlite(
column_type=_type,
sqlite_json1_extension_enabled=self._sqlite_json1_extension_enabled,
)
not in Integer_Types
for _type in types.split(",")
):
primary += "\n\tPRIMARY KEY ({})".format(
", ".join(f'"{column}"' for column in columns.split(","))
)
else:
indices += """CREATE {unique} INDEX IF NOT EXISTS "{name}" ON "{table}" ({columns});""".format(
unique="UNIQUE" if index["unique"] in {1, "1"} else "",
name=f"{table_name}_{index_name}"
if (table_collisions > 0 or self._prefix_indices)
else index_name,
name=(
f"{table_name}_{index_name}"
if (table_collisions > 0 or self._prefix_indices)
else index_name
),
table=table_name,
columns=", ".join(f'"{column}"' for column in columns.split(",")),
)
Expand Down Expand Up @@ -481,9 +516,11 @@ def _build_create_table_sql(self, table_name: str) -> str:
c.UPDATE_RULE,
c.DELETE_RULE
""".format(
JOIN="JOIN"
if (server_version is not None and server_version[0] == 8 and server_version[2] > 19)
else "LEFT JOIN"
JOIN=(
"JOIN"
if (server_version is not None and server_version[0] == 8 and server_version[2] > 19)
else "LEFT JOIN"
)
),
(self._mysql_database, table_name, "FOREIGN KEY"),
)
Expand Down
1 change: 1 addition & 0 deletions mysql_to_sqlite3/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Types for mysql-to-sqlite3."""

import os
import typing as t
from logging import Logger
Expand Down
40 changes: 18 additions & 22 deletions tests/func/mysql_to_sqlite3_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,14 +433,14 @@ def test_transfer_transfers_all_tables_from_mysql_to_sqlite(
mysql_inspect: Inspector = inspect(mysql_engine)
mysql_tables: t.List[str] = mysql_inspect.get_table_names()

mysql_connector_connection: t.Union[
PooledMySQLConnection, MySQLConnection, CMySQLConnection
] = mysql.connector.connect(
user=mysql_credentials.user,
password=mysql_credentials.password,
host=mysql_credentials.host,
port=mysql_credentials.port,
database=mysql_credentials.database,
mysql_connector_connection: t.Union[PooledMySQLConnection, MySQLConnection, CMySQLConnection] = (
mysql.connector.connect(
user=mysql_credentials.user,
password=mysql_credentials.password,
host=mysql_credentials.host,
port=mysql_credentials.port,
database=mysql_credentials.database,
)
)
server_version: t.Tuple[int, ...] = mysql_connector_connection.get_server_version()

Expand Down Expand Up @@ -490,9 +490,7 @@ def test_transfer_transfers_all_tables_from_mysql_to_sqlite(
AND i.CONSTRAINT_TYPE = :constraint_type
""".format(
# MySQL 8.0.19 still works with "LEFT JOIN" everything above requires "JOIN"
JOIN="JOIN"
if (server_version[0] == 8 and server_version[2] > 19)
else "LEFT JOIN"
JOIN="JOIN" if (server_version[0] == 8 and server_version[2] > 19) else "LEFT JOIN"
)
).bindparams(
table_schema=mysql_credentials.database,
Expand Down Expand Up @@ -1183,14 +1181,14 @@ def test_transfer_limited_rows_from_mysql_to_sqlite(
mysql_inspect: Inspector = inspect(mysql_engine)
mysql_tables: t.List[str] = mysql_inspect.get_table_names()

mysql_connector_connection: t.Union[
PooledMySQLConnection, MySQLConnection, CMySQLConnection
] = mysql.connector.connect(
user=mysql_credentials.user,
password=mysql_credentials.password,
host=mysql_credentials.host,
port=mysql_credentials.port,
database=mysql_credentials.database,
mysql_connector_connection: t.Union[PooledMySQLConnection, MySQLConnection, CMySQLConnection] = (
mysql.connector.connect(
user=mysql_credentials.user,
password=mysql_credentials.password,
host=mysql_credentials.host,
port=mysql_credentials.port,
database=mysql_credentials.database,
)
)
server_version: t.Tuple[int, ...] = mysql_connector_connection.get_server_version()

Expand Down Expand Up @@ -1240,9 +1238,7 @@ def test_transfer_limited_rows_from_mysql_to_sqlite(
AND i.CONSTRAINT_TYPE = :constraint_type
""".format(
# MySQL 8.0.19 still works with "LEFT JOIN" everything above requires "JOIN"
JOIN="JOIN"
if (server_version[0] == 8 and server_version[2] > 19)
else "LEFT JOIN"
JOIN="JOIN" if (server_version[0] == 8 and server_version[2] > 19) else "LEFT JOIN"
)
).bindparams(
table_schema=mysql_credentials.database,
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ deps =
mypy>=1.3.0
-rrequirements_dev.txt
commands =
mypy mysql_to_sqlite3 --enable-incomplete-feature=Unpack
mypy mysql_to_sqlite3

[testenv:linters]
basepython = python3
Expand Down

0 comments on commit 8d3f70a

Please sign in to comment.