diff --git a/docs/_toc.yml b/docs/_toc.yml index 9c55fc886c..b049af511c 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -60,6 +60,7 @@ parts: - file: source/reference/databases/postgres - file: source/reference/databases/sqlite - file: source/reference/databases/mysql + - file: source/reference/databases/mariadb - file: source/reference/ai/index title: AI Engines diff --git a/docs/source/reference/databases/mariadb.rst b/docs/source/reference/databases/mariadb.rst new file mode 100644 index 0000000000..fdbb443bb6 --- /dev/null +++ b/docs/source/reference/databases/mariadb.rst @@ -0,0 +1,36 @@ +MariaDB +========== + +The connection to MariaDB is based on the `mariadb <https://pypi.org/project/mariadb/>`_ library. + +Dependency +---------- + +* mariadb + + +Parameters +---------- + +Required: + +* `user` is the username corresponding to the database +* `password` is the password for the above username for the database +* `database` is the database name +* `host` is the host name, IP address or the URL +* `port` is the port used to make the TCP/IP connection. + + +Create Connection +----------------- + +.. code-block:: text + + CREATE DATABASE mariadb_data WITH ENGINE = 'mariadb', PARAMETERS = { + "user" : "eva", + "password": "password", + "host": "127.0.0.1",
+ "port": "7567", + "database": "evadb" + }; + diff --git a/evadb/third_party/databases/interface.py b/evadb/third_party/databases/interface.py index f49403b92e..f0c3ee14ec 100644 --- a/evadb/third_party/databases/interface.py +++ b/evadb/third_party/databases/interface.py @@ -37,6 +37,8 @@ def get_database_handler(engine: str, **kwargs): return mod.SQLiteHandler(engine, **kwargs) elif engine == "mysql": return mod.MysqlHandler(engine, **kwargs) + elif engine == "mariadb": + return mod.MariaDbHandler(engine, **kwargs) else: raise NotImplementedError(f"Engine {engine} is not supported") diff --git a/evadb/third_party/databases/mariadb/__init__.py b/evadb/third_party/databases/mariadb/__init__.py new file mode 100644 index 0000000000..4a840209af --- /dev/null +++ b/evadb/third_party/databases/mariadb/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""mariadb integrations""" diff --git a/evadb/third_party/databases/mariadb/mariadb_handler.py b/evadb/third_party/databases/mariadb/mariadb_handler.py new file mode 100644 index 0000000000..8da40e0981 --- /dev/null +++ b/evadb/third_party/databases/mariadb/mariadb_handler.py @@ -0,0 +1,176 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class MariaDbHandler(DBHandler):
    """
    Database handler that lets EvaDB use a MariaDB server as a backend store.

    The handler keeps a single DB-API connection in ``self.connection``;
    the attribute is ``None`` whenever the handler is not connected.
    """

    # Python types for the ``data_type`` strings that get_columns() reads
    # from ``information_schema.columns``.
    _TYPE_MAPPING = {
        "tinyint": int,
        "smallint": int,
        "mediumint": int,
        "bigint": int,
        "int": int,
        "decimal": float,
        "float": float,
        "double": float,
        "text": str,
        "tinytext": str,
        "mediumtext": str,
        "longtext": str,
        "string literals": str,
        "char": str,
        "varchar": str,
        "boolean": bool,
        # Add more mappings as needed
    }

    _NOT_CONNECTED = "Not connected to the database."

    def __init__(self, name: str, **kwargs):
        """
        Initialize the handler.
        Args:
            name (str): name of the DB handler instance
            **kwargs: arbitrary keyword arguments for establishing the
                connection. The keys read here are ``host``, ``port``,
                ``user``, ``password`` and ``database``; missing keys are
                stored as ``None``.
        """
        super().__init__(name)
        self.host = kwargs.get("host")
        self.port = kwargs.get("port")
        self.user = kwargs.get("user")
        self.password = kwargs.get("password")
        self.database = kwargs.get("database")
        # Make the "not connected" state explicit so every method can test
        # ``self.connection`` even when connect() was never called.
        self.connection = None

    def connect(self) -> DBHandlerStatus:
        """
        Establish connection to the database.

        Parameters that were not supplied are omitted from the driver call
        so that the driver's own defaults apply, and the port is converted
        to ``int`` because connection parameters may arrive as strings
        (the documented CREATE DATABASE example passes ``"port": "7567"``).

        Returns:
            DBHandlerStatus: ``status=True`` on success, otherwise
            ``status=False`` with the error text.
        """
        candidate_params = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "password": self.password,
            "database": self.database,
        }
        connect_params = {
            key: value for key, value in candidate_params.items() if value is not None
        }

        if "port" in connect_params:
            try:
                connect_params["port"] = int(connect_params["port"])
            except (TypeError, ValueError):
                return DBHandlerStatus(
                    status=False, error=f"Invalid port: {self.port!r}"
                )

        try:
            self.connection = mariadb.connect(**connect_params)
            # Auto commit is off by default.
            self.connection.autocommit = True
            return DBHandlerStatus(status=True)
        except mariadb.Error as e:
            return DBHandlerStatus(status=False, error=str(e))

    def disconnect(self):
        """
        Disconnect from the database.

        The stored connection is reset to ``None`` afterwards so that
        check_connection() and the query methods report "not connected"
        instead of operating on a closed connection.
        """
        if self.connection:
            self.connection.close()
            self.connection = None

    def check_connection(self) -> DBHandlerStatus:
        """
        Method for checking the status of database connection.
        Returns:
            DBHandlerStatus: ``status=True`` when a connection object is
            held, otherwise ``status=False`` with an error message.
        """
        if self.connection:
            return DBHandlerStatus(status=True)
        return DBHandlerStatus(status=False, error=self._NOT_CONNECTED)

    def get_tables(self) -> DBHandlerResponse:
        """
        Method to get the list of tables from database.
        Returns:
            DBHandlerResponse: ``data`` is a dataframe with a single
            ``table_name`` column, or ``None`` with ``error`` set.
        """
        if not self.connection:
            return DBHandlerResponse(data=None, error=self._NOT_CONNECTED)

        # The schema name is bound as a parameter rather than interpolated
        # into the SQL text.
        query = (
            "SELECT table_name AS 'table_name' FROM information_schema.tables "
            "WHERE table_schema = ?"
        )
        try:
            tables_df = self._run_query(query, (self.database,))
            return DBHandlerResponse(data=tables_df)
        except mariadb.Error as e:
            return DBHandlerResponse(data=None, error=str(e))

    def get_columns(self, table_name: str) -> DBHandlerResponse:
        """
        Method to retrieve the columns of the specified table from the database.
        Args:
            table_name (str): name of the table whose columns are to be retrieved.
        Returns:
            DBHandlerResponse: ``data`` is a dataframe with the columns
            ``name`` and ``dtype`` (a Python type), or ``None`` with
            ``error`` set.
        Raises:
            Exception: if the table contains a column type that has no
            Python mapping (see _mariadb_to_python_types).
        """
        if not self.connection:
            return DBHandlerResponse(data=None, error=self._NOT_CONNECTED)

        query = (
            "SELECT column_name AS 'name', data_type AS 'dtype' "
            "FROM information_schema.columns WHERE table_name = ?"
        )
        params = [table_name]
        if self.database is not None:
            # Restrict to the configured schema so that a table with the
            # same name in another schema does not contribute columns.
            query += " AND table_schema = ?"
            params.append(self.database)

        try:
            columns_df = self._run_query(query, tuple(params))
        except mariadb.Error as e:
            return DBHandlerResponse(data=None, error=str(e))

        columns_df["dtype"] = columns_df["dtype"].apply(self._mariadb_to_python_types)
        return DBHandlerResponse(data=columns_df)

    def _run_query(self, query: str, params=None) -> pd.DataFrame:
        """
        Execute ``query`` on a fresh cursor and return its result as a
        dataframe. The cursor is always closed, also when execution fails.
        Driver errors propagate to the caller.
        """
        cursor = self.connection.cursor()
        try:
            if params is None:
                cursor.execute(query)
            else:
                cursor.execute(query, params)
            return self._fetch_results_as_df(cursor)
        finally:
            cursor.close()

    def _fetch_results_as_df(self, cursor) -> pd.DataFrame:
        """
        Fetch results from the cursor for the executed query and return the
        query results as dataframe.

        Statements that produce no result set (INSERT, UPDATE, DELETE, DDL)
        yield a one-row dataframe ``{"status": ["success"]}``.
        """
        # DB-API 2.0: ``description`` is None when the last operation did
        # not return rows. Checking it avoids depending on the wording of a
        # driver-specific error message.
        if cursor.description is None:
            return pd.DataFrame({"status": ["success"]})

        rows = cursor.fetchall()
        column_names = [desc[0] for desc in cursor.description]
        return pd.DataFrame(rows, columns=column_names)

    def execute_native_query(self, query_string: str) -> DBHandlerResponse:
        """
        Executes the native query on the database.
        Args:
            query_string (str): query in native format
        Returns:
            DBHandlerResponse: ``data`` holds the result dataframe, or
            ``None`` with ``error`` set when the driver reports an error.
        """
        if not self.connection:
            return DBHandlerResponse(data=None, error=self._NOT_CONNECTED)

        try:
            return DBHandlerResponse(data=self._run_query(query_string))
        except mariadb.Error as e:
            return DBHandlerResponse(data=None, error=str(e))

    def _mariadb_to_python_types(self, mariadb_type: str):
        """
        Translate a MariaDB ``data_type`` name into the matching Python type.
        Args:
            mariadb_type (str): type name as stored in information_schema.
        Returns:
            type: one of ``int``, ``float``, ``str`` or ``bool``.
        Raises:
            Exception: if the type name has no known mapping.
        """
        python_type = self._TYPE_MAPPING.get(mariadb_type)
        if python_type is None:
            raise Exception(
                f"Unsupported column {mariadb_type} encountered in the MariaDB. Please raise a feature request!"
            )
        return python_type
class NativeQueryResponse:
    """Minimal stand-in for a handler response: no error and no data."""

    def __init__(self):
        self.error = None
        self.data = None


@pytest.mark.notparallel
class MariaDbStorageEngineTest(unittest.TestCase):
    """
    Checks that ``USE test_data_source { ... }`` statements are routed
    through the MariaDB handler: one catalog lookup, one connect, one native
    query execution and one disconnect per statement. The ``mariadb`` module
    and the handler's connect / execute / disconnect methods are mocked.
    """

    _HANDLER_PATH = (
        "evadb.third_party.databases.mariadb.mariadb_handler.MariaDbHandler"
    )

    def get_mariadb_params(self):
        """Connection parameters stored in the mocked catalog entry."""
        return {"user": "eva", "password": "password", "database": "evadb"}

    def _install_fake_mariadb(self):
        """
        Put a MagicMock in ``sys.modules["mariadb"]`` for the duration of
        the test and restore the previous state afterwards, so the fake
        module does not leak into other tests of the same process.
        """
        missing = object()
        previous = sys.modules.get("mariadb", missing)
        sys.modules["mariadb"] = MagicMock()

        def restore():
            if previous is missing:
                sys.modules.pop("mariadb", None)
            else:
                sys.modules["mariadb"] = previous

        self.addCleanup(restore)

    def _start_patch(self, target):
        """
        Start a patch on ``target`` and register its stop as a cleanup, so
        it is undone even when a later step of setUp fails.
        """
        patcher = patch(target)
        mock = patcher.start()
        self.addCleanup(patcher.stop)
        return mock

    def setUp(self):
        connection_params = self.get_mariadb_params()
        self.evadb = get_evadb_for_testing()

        # Must happen before the handler module is patched: resolving the
        # patch target imports the handler module, which imports mariadb.
        self._install_fake_mariadb()

        self.get_database_catalog_entry_mock = self._start_patch(
            "evadb.catalog.catalog_manager.CatalogManager.get_database_catalog_entry"
        )
        self.execute_native_query_mock = self._start_patch(
            f"{self._HANDLER_PATH}.execute_native_query"
        )
        self.connect_mock = self._start_patch(f"{self._HANDLER_PATH}.connect")
        self.disconnect_mock = self._start_patch(f"{self._HANDLER_PATH}.disconnect")

        # set return values
        self.execute_native_query_mock.return_value = NativeQueryResponse()
        self.get_database_catalog_entry_mock.return_value = DatabaseCatalogEntry(
            name="test_data_source",
            engine="mariadb",
            params=connection_params,
            row_id=1,
        )

    def _run_and_check_single_roundtrip(self, query):
        """Execute ``query`` and assert each mocked step ran exactly once."""
        execute_query_fetch_all(self.evadb, query)

        self.connect_mock.assert_called_once()
        self.execute_native_query_mock.assert_called_once()
        self.get_database_catalog_entry_mock.assert_called_once()
        self.disconnect_mock.assert_called_once()

    def test_execute_mariadb_select_query(self):
        self._run_and_check_single_roundtrip(
            """USE test_data_source {
                SELECT * FROM test_table
            }"""
        )

    def test_execute_mariadb_insert_query(self):
        self._run_and_check_single_roundtrip(
            """USE test_data_source {
                INSERT INTO test_table (
                    name, age, comment
                ) VALUES (
                    'val', 5, 'testing'
                )
            }"""
        )

    def test_execute_mariadb_update_query(self):
        self._run_and_check_single_roundtrip(
            """USE test_data_source {
                UPDATE test_table
                SET comment = 'update'
                WHERE age > 5
            }"""
        )

    def test_execute_mariadb_delete_query(self):
        self._run_and_check_single_roundtrip(
            """USE test_data_source {
                DELETE FROM test_table
                WHERE age < 5
            }"""
        )