diff --git a/evadb/executor/use_executor.py b/evadb/executor/use_executor.py index e71f1dab3..f8f66df7a 100644 --- a/evadb/executor/use_executor.py +++ b/evadb/executor/use_executor.py @@ -14,18 +14,16 @@ # limitations under the License. from typing import Iterator -import pandas as pd -from sqlalchemy import create_engine - -from evadb.catalog.catalog_utils import generate_sqlalchemy_conn_str from evadb.database import EvaDBDatabase from evadb.executor.abstract_executor import AbstractExecutor +from evadb.executor.executor_utils import ExecutorError from evadb.models.storage.batch import Batch -from evadb.plan_nodes.native_plan import SQLAlchemyPlan +from evadb.parser.use_statement import UseStatement +from evadb.third_party.databases.interface import get_database_handler class UseExecutor(AbstractExecutor): - def __init__(self, db: EvaDBDatabase, node: SQLAlchemyPlan): + def __init__(self, db: EvaDBDatabase, node: UseStatement): super().__init__(db, node) self._database_name = node.database_name self._query_string = node.query_string @@ -35,16 +33,16 @@ def exec(self, *args, **kwargs) -> Iterator[Batch]: self._database_name ) - conn_str = generate_sqlalchemy_conn_str( + handler = get_database_handler( db_catalog_entry.engine, - db_catalog_entry.params, + **db_catalog_entry.params, ) - engine = create_engine(conn_str) + handler.connect() + resp = handler.execute_native_query(self._query_string) + handler.disconnect() - with engine.connect() as con: - if "SELECT" in self._query_string or "select" in self._query_string: - yield Batch(pd.read_sql(self._query_string, engine)) - else: - con.execute(self._query_string) - yield Batch(pd.DataFrame({"status": ["Ok"]})) + if resp.error is None: + return Batch(resp.data) + else: + raise ExecutorError(resp.error) diff --git a/evadb/third_party/databases/__init__.py b/evadb/third_party/databases/__init__.py new file mode 100644 index 000000000..59a62859b --- /dev/null +++ b/evadb/third_party/databases/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Database integrations""" diff --git a/evadb/third_party/databases/interface.py b/evadb/third_party/databases/interface.py new file mode 100644 index 000000000..321004334 --- /dev/null +++ b/evadb/third_party/databases/interface.py @@ -0,0 +1,64 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import os + +import pip + +INSTALL_CACHE = [] + + +def get_database_handler(engine: str, **kwargs): + """ + Return the database handler. User should modify this function for + their new integrated handlers. + """ + + # Dynamically install dependencies. + dynamic_install(engine) + + # Dynamically import the top module. + mod = dynamic_import(engine) + + if engine == "postgres": + return mod.PostgresHandler(engine, **kwargs) + else: + raise NotImplementedError(f"Engine {engine} is not supported") + + +def dynamic_install(handler_dir): + """ + Dynamically install package from requirements.txt. + """ + + # Skip installation + if handler_dir in INSTALL_CACHE: + return + + INSTALL_CACHE.append(handler_dir) + + req_file = os.path.join(handler_dir, "requirements.txt") + if os.path.isfile(req_file): + with open(req_file) as f: + for package in f.read().splitlines(): + if hasattr(pip, "main"): + pip.main(["install", package]) + else: + pip._internal.main(["install", package]) + + +def dynamic_import(handler_dir): + import_path = f"evadb.third_party.databases.{handler_dir}.{handler_dir}_handler" + return importlib.import_module(import_path) diff --git a/evadb/third_party/databases/postgres/__init__.py b/evadb/third_party/databases/postgres/__init__.py new file mode 100644 index 000000000..28f4fa393 --- /dev/null +++ b/evadb/third_party/databases/postgres/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""postgres integrations""" diff --git a/evadb/third_party/databases/postgres/postgres_handler.py b/evadb/third_party/databases/postgres/postgres_handler.py new file mode 100644 index 000000000..4721536e2 --- /dev/null +++ b/evadb/third_party/databases/postgres/postgres_handler.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pandas as pd +import psycopg2 + +from evadb.third_party.databases.types import ( + DBHandler, + DBHandlerResponse, + DBHandlerStatus, +) + + +class PostgresHandler(DBHandler): + def __init__(self, name: str, **kwargs): + super().__init__(name) + self.host = kwargs.get("host") + self.port = kwargs.get("port") + self.user = kwargs.get("user") + self.password = kwargs.get("password") + self.database = kwargs.get("database") + + def connect(self): + try: + self.connection = psycopg2.connect( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + database=self.database, + ) + self.connection.autocommit = True + return DBHandlerStatus(status=True) + except psycopg2.Error as e: + return DBHandlerStatus(status=False, error=str(e)) + + def disconnect(self): + if self.connection: + self.connection.close() + + def check_connection(self) -> DBHandlerStatus: + if self.connection: + return DBHandlerStatus(status=True) + else: + return DBHandlerStatus(status=False, error="Not connected to the database.") + + def get_tables(self) -> DBHandlerResponse: + if not self.connection: + return DBHandlerResponse(data=None, error="Not connected to the database.") + + try: + query = "SELECT table_name FROM information_schema.tables WHERE table_schema NOT IN ('information_schema', 'pg_catalog')" + tables_df = pd.read_sql_query(query, self.connection) + return DBHandlerResponse(data=tables_df) + except psycopg2.Error as e: + return DBHandlerResponse(data=None, error=str(e)) + + def get_columns(self, table_name: str) -> DBHandlerResponse: + if not self.connection: + return DBHandlerResponse(data=None, error="Not connected to the database.") + + try: + query = f"SELECT column_name FROM information_schema.columns WHERE table_name='{table_name}'" + columns_df = pd.read_sql_query(query, self.connection) + return DBHandlerResponse(data=columns_df) + except psycopg2.Error as e: + return DBHandlerResponse(data=None, error=str(e)) + + def _fetch_results_as_df(self, cursor): + """ + This is currently the only clean solution that we have found so far. + Reference to Postgres API: https://www.psycopg.org/docs/cursor.html#fetch + + In short, currently there is no very clean programming way to differentiate + CREATE, INSERT, SELECT. CREATE and INSERT do not return any result, so calling + fetchall() on those will yield a programming error. Cursor has an attribute + rowcount, but it indicates # of rows that are affected. In that case, for both + INSERT and SELECT rowcount is not 0, so we also cannot use this API to + differentiate INSERT and SELECT. + """ + try: + res = cursor.fetchall() + res_df = pd.DataFrame(res, columns=[desc[0] for desc in cursor.description]) + return res_df + except psycopg2.ProgrammingError as e: + if str(e) == "no results to fetch": + return pd.DataFrame({"status": ["success"]}) + raise e + + def execute_native_query(self, query_string: str) -> DBHandlerResponse: + if not self.connection: + return DBHandlerResponse(data=None, error="Not connected to the database.") + + try: + cursor = self.connection.cursor() + cursor.execute(query_string) + return DBHandlerResponse(data=self._fetch_results_as_df(cursor)) + except psycopg2.Error as e: + return DBHandlerResponse(data=None, error=str(e)) diff --git a/evadb/third_party/databases/postgres/requirements.txt b/evadb/third_party/databases/postgres/requirements.txt new file mode 100644 index 000000000..658130bb2 --- /dev/null +++ b/evadb/third_party/databases/postgres/requirements.txt @@ -0,0 +1 @@ +psycopg2 diff --git a/evadb/third_party/databases/types.py b/evadb/third_party/databases/types.py new file mode 100644 index 000000000..bdedb3a4b --- /dev/null +++ b/evadb/third_party/databases/types.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass + +import pandas as pd + + +@dataclass +class DBHandlerResponse: + """ + Represents the response from a database handler containing data and an optional error message. + + Attributes: + data (pd.DataFrame): A Pandas DataFrame containing the data retrieved from the database. + error (str, optional): An optional error message indicating any issues encountered during the operation. + """ + + data: pd.DataFrame + error: str = None + + +@dataclass +class DBHandlerStatus: + """ + Represents the status of a database handler operation, along with an optional error message. + + Attributes: + status (bool): A boolean indicating the success (True) or failure (False) of the operation. + error (str, optional): An optional error message providing details about any errors that occurred. + """ + + status: bool + error: str = None + + +class DBHandler: + """ + Base class for handling database operations. + + Args: + name (str): The name associated with the database handler instance. + """ + + def __init__(self, name: str): + self.name = name + + def connect(self): + """ + Establishes a connection to the database. + + Raises: + NotImplementedError: This method should be implemented in derived classes. + """ + raise NotImplementedError() + + def disconnect(self): + """ + Disconnects from the database. + + This method can be overridden in derived classes to perform specific disconnect actions. + """ + raise NotImplementedError() + + def check_connection(self) -> DBHandlerStatus: + """ + Checks the status of the database connection. + + Returns: + DBHandlerStatus: An instance of DBHandlerStatus indicating the connection status. + + Raises: + NotImplementedError: This method should be implemented in derived classes. + """ + raise NotImplementedError() + + def get_tables(self) -> DBHandlerResponse: + """ + Retrieves the list of tables from the database. + + Returns: + DBHandlerResponse: An instance of DBHandlerResponse containing the list of tables or an error message. Data is in a pandas DataFrame. + + Raises: + NotImplementedError: This method should be implemented in derived classes. + """ + raise NotImplementedError() + + def get_columns(self, table_name: str) -> DBHandlerResponse: + """ + Retrieves the columns of a specified table from the database. + + Args: + table_name (str): The name of the table for which to retrieve columns. + + Returns: + DBHandlerResponse: An instance of DBHandlerResponse containing the columns or an error message. Data is in a pandas DataFrame. + + Raises: + NotImplementedError: This method should be implemented in derived classes. + """ + raise NotImplementedError() + + def execute_native_query(self, query_string: str) -> DBHandlerResponse: + """ + Executes the query through the handler's database engine. + + Args: + query_string (str): The string representation of the native query. + + Returns: + DBHandlerResponse: An instance of DBHandlerResponse containing the columns or an error message. Data is in a pandas DataFrame. + + Raises: + NotImplementedError: This method should be implemented in derived classes. + """ + raise NotImplementedError()