diff --git a/docs/_toc.yml b/docs/_toc.yml index 6a37f98261..6739643319 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -71,6 +71,7 @@ parts: - file: source/reference/databases/mariadb - file: source/reference/databases/clickhouse - file: source/reference/databases/github + - file: source/reference/databases/snowflake - file: source/reference/vector_databases/index title: Vector Databases diff --git a/docs/source/reference/databases/snowflake.rst b/docs/source/reference/databases/snowflake.rst new file mode 100644 index 0000000000..239389e186 --- /dev/null +++ b/docs/source/reference/databases/snowflake.rst @@ -0,0 +1,47 @@ +Snowflake +========== + +The connection to Snowflake is based on the `snowflake-connector-python `_ library. + +Dependency +---------- + +* snowflake-connector-python + +Parameters +---------- + +Required: + +* `user` is the database user. +* `password` is the snowflake account password. +* `database` is the database name. +* `warehouse` is the snowflake warehouse name. +* `account` is the snowflake account number ( can be found in the url ). +* `schema` is the schema name. + + +.. warning:: + + Provide the parameters of an already running ``Snowflake`` Data Warehouse. EvaDB only connects to an existing ``Snowflake`` Data Warehouse. + +Create Connection +----------------- + +.. code-block:: text + + CREATE DATABASE snowflake_data WITH ENGINE = 'snowflake', PARAMETERS = { + "user": "", + "password": "" + "account": "", + "database": "EVADB", + "warehouse": "COMPUTE_WH", + "schema": "SAMPLE_DATA" + }; + +.. warning:: + + | In Snowflake Terminology, ``Database`` and ``Schema`` refer to the following. + | A database is a logical grouping of schemas. Each database belongs to a single Snowflake account. + | A schema is a logical grouping of database objects (tables, views, etc.). Each schema belongs to a single database. + diff --git a/evadb/third_party/databases/interface.py b/evadb/third_party/databases/interface.py index 5e30dc8220..e4cd86151c 100644 --- a/evadb/third_party/databases/interface.py +++ b/evadb/third_party/databases/interface.py @@ -44,6 +44,8 @@ def _get_database_handler(engine: str, **kwargs): return mod.MariaDbHandler(engine, **kwargs) elif engine == "clickhouse": return mod.ClickHouseHandler(engine, **kwargs) + elif engine == "snowflake": + return mod.SnowFlakeDbHandler(engine, **kwargs) elif engine == "github": return mod.GithubHandler(engine, **kwargs) else: diff --git a/evadb/third_party/databases/snowflake/__init__.py b/evadb/third_party/databases/snowflake/__init__.py new file mode 100644 index 0000000000..c881047fe1 --- /dev/null +++ b/evadb/third_party/databases/snowflake/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""snowflake integrations""" diff --git a/evadb/third_party/databases/snowflake/requirements.txt b/evadb/third_party/databases/snowflake/requirements.txt new file mode 100644 index 0000000000..c366baf101 --- /dev/null +++ b/evadb/third_party/databases/snowflake/requirements.txt @@ -0,0 +1,3 @@ +snowflake-connector-python +pyarrow +pandas \ No newline at end of file diff --git a/evadb/third_party/databases/snowflake/snowflake_handler.py b/evadb/third_party/databases/snowflake/snowflake_handler.py new file mode 100644 index 0000000000..0b5ba4553d --- /dev/null +++ b/evadb/third_party/databases/snowflake/snowflake_handler.py @@ -0,0 +1,184 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime + +import pandas as pd +import snowflake.connector + +from evadb.third_party.databases.types import ( + DBHandler, + DBHandlerResponse, + DBHandlerStatus, +) + + +class SnowFlakeDbHandler(DBHandler): + + """ + Class for implementing the SnowFlake DB handler as a backend store for + EvaDB. + """ + + def __init__(self, name: str, **kwargs): + """ + Initialize the handler. + Args: + name (str): name of the DB handler instance + **kwargs: arbitrary keyword arguments for establishing the connection. + """ + super().__init__(name) + self.user = kwargs.get("user") + self.password = kwargs.get("password") + self.database = kwargs.get("database") + self.warehouse = kwargs.get("warehouse") + self.account = kwargs.get("account") + self.schema = kwargs.get("schema") + + def connect(self): + """ + Establish connection to the database. + Returns: + DBHandlerStatus + """ + try: + self.connection = snowflake.connector.connect( + user=self.user, + password=self.password, + database=self.database, + warehouse=self.warehouse, + schema=self.schema, + account=self.account, + ) + # Auto commit is off by default. + self.connection.autocommit = True + return DBHandlerStatus(status=True) + except snowflake.connector.errors.Error as e: + return DBHandlerStatus(status=False, error=str(e)) + + def disconnect(self): + """ + Disconnect from the database. + """ + if self.connection: + self.connection.close() + + def check_connection(self) -> DBHandlerStatus: + """ + Method for checking the status of database connection. + Returns: + DBHandlerStatus + """ + if self.connection: + return DBHandlerStatus(status=True) + else: + return DBHandlerStatus(status=False, error="Not connected to the database.") + + def get_tables(self) -> DBHandlerResponse: + """ + Method to get the list of tables from database. + Returns: + DBHandlerStatus + """ + if not self.connection: + return DBHandlerResponse(data=None, error="Not connected to the database.") + + try: + query = f"SELECT table_name as table_name FROM information_schema.tables WHERE table_schema='{self.schema}'" + cursor = self.connection.cursor() + cursor.execute(query) + tables_df = self._fetch_results_as_df(cursor) + return DBHandlerResponse(data=tables_df) + except snowflake.connector.errors.Error as e: + return DBHandlerResponse(data=None, error=str(e)) + + def get_columns(self, table_name: str) -> DBHandlerResponse: + """ + Method to retrieve the columns of the specified table from the database. + Args: + table_name (str): name of the table whose columns are to be retrieved. + Returns: + DBHandlerStatus + """ + if not self.connection: + return DBHandlerResponse(data=None, error="Not connected to the database.") + + try: + query = f"SELECT column_name as name, data_type as dtype FROM information_schema.columns WHERE table_name='{table_name}'" + cursor = self.connection.cursor() + cursor.execute(query) + columns_df = self._fetch_results_as_df(cursor) + columns_df["dtype"] = columns_df["dtype"].apply( + self._snowflake_to_python_types + ) + return DBHandlerResponse(data=columns_df) + except snowflake.connector.errors.Error as e: + return DBHandlerResponse(data=None, error=str(e)) + + def _fetch_results_as_df(self, cursor): + """ + Fetch results from the cursor for the executed query and return the + query results as dataframe. + """ + try: + res = cursor.fetchall() + res_df = pd.DataFrame( + res, columns=[desc[0].lower() for desc in cursor.description] + ) + return res_df + except snowflake.connector.errors.ProgrammingError as e: + if str(e) == "no results to fetch": + return pd.DataFrame({"status": ["success"]}) + raise e + + def execute_native_query(self, query_string: str) -> DBHandlerResponse: + """ + Executes the native query on the database. + Args: + query_string (str): query in native format + Returns: + DBHandlerResponse + """ + if not self.connection: + return DBHandlerResponse(data=None, error="Not connected to the database.") + + try: + cursor = self.connection.cursor() + cursor.execute(query_string) + return DBHandlerResponse(data=self._fetch_results_as_df(cursor)) + except snowflake.connector.errors.Error as e: + return DBHandlerResponse(data=None, error=str(e)) + + def _snowflake_to_python_types(self, snowflake_type: str): + mapping = { + "TEXT": str, + "NUMBER": int, + "INT": int, + "DECIMAL": float, + "STRING": str, + "CHAR": str, + "BOOLEAN": bool, + "BINARY": bytes, + "DATE": datetime.date, + "TIME": datetime.time, + "TIMESTAMP": datetime.datetime + # Add more mappings as needed + } + + if snowflake_type in mapping: + return mapping[snowflake_type] + else: + raise Exception( + f"Unsupported column {snowflake_type} encountered in the snowflake. Please raise a feature request!" + ) diff --git a/test/third_party_tests/test_native_executor.py b/test/third_party_tests/test_native_executor.py index 7eaf27cb85..879435f866 100644 --- a/test/third_party_tests/test_native_executor.py +++ b/test/third_party_tests/test_native_executor.py @@ -228,6 +228,28 @@ def test_should_run_query_in_clickhouse(self): self._execute_native_query() self._execute_evadb_query() + @pytest.mark.skip( + reason="Snowflake does not come with a free version of account, so integration test is not feasible" + ) + def test_should_run_query_in_snowflake(self): + # Create database. + params = { + "user": "eva", + "password": "password", + "account": "account_number", + "database": "EVADB", + "schema": "SAMPLE_DATA", + "warehouse": "warehouse", + } + query = f"""CREATE DATABASE test_data_source + WITH ENGINE = "snowflake", + PARAMETERS = {params};""" + execute_query_fetch_all(self.evadb, query) + + # Test executions. + self._execute_native_query() + self._execute_evadb_query() + def test_should_run_query_in_sqlite(self): # Create database. import os diff --git a/test/unit_tests/storage/test_snowflake_native_storage_engine.py b/test/unit_tests/storage/test_snowflake_native_storage_engine.py new file mode 100644 index 0000000000..50fcea23b3 --- /dev/null +++ b/test/unit_tests/storage/test_snowflake_native_storage_engine.py @@ -0,0 +1,144 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest +from test.util import get_evadb_for_testing +from unittest.mock import MagicMock, patch + +import pytest + +from evadb.catalog.models.utils import DatabaseCatalogEntry +from evadb.server.command_handler import execute_query_fetch_all + +sys.modules["snowflake"] = MagicMock() +sys.modules["snowflake.connector"] = MagicMock() + + +class NativeQueryResponse: + def __init__(self): + self.error = None + self.data = None + + +@pytest.mark.notparallel +class SnowFlakeNativeStorageEngineTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_snowflake_params(self): + return { + "database": "evadb.db", + } + + def setUp(self): + connection_params = self.get_snowflake_params() + self.evadb = get_evadb_for_testing() + + # Create all class level patches + self.get_database_catalog_entry_patcher = patch( + "evadb.catalog.catalog_manager.CatalogManager.get_database_catalog_entry" + ) + self.get_database_catalog_entry_mock = ( + self.get_database_catalog_entry_patcher.start() + ) + + self.execute_native_query_patcher = patch( + "evadb.third_party.databases.snowflake.snowflake_handler.SnowFlakeDbHandler.execute_native_query" + ) + self.execute_native_query_mock = self.execute_native_query_patcher.start() + + self.connect_patcher = patch( + "evadb.third_party.databases.snowflake.snowflake_handler.SnowFlakeDbHandler.connect" + ) + self.connect_mock = self.connect_patcher.start() + + self.disconnect_patcher = patch( + "evadb.third_party.databases.snowflake.snowflake_handler.SnowFlakeDbHandler.disconnect" + ) + self.disconnect_mock = self.disconnect_patcher.start() + + # set return values + self.execute_native_query_mock.return_value = NativeQueryResponse() + self.get_database_catalog_entry_mock.return_value = DatabaseCatalogEntry( + name="test_data_source", + engine="snowflake", + params=connection_params, + row_id=1, + ) + + def tearDown(self): + self.get_database_catalog_entry_patcher.stop() + self.execute_native_query_patcher.stop() + self.connect_patcher.stop() + self.disconnect_patcher.stop() + + def test_execute_snowflake_select_query(self): + execute_query_fetch_all( + self.evadb, + """USE test_data_source { + SELECT * FROM test_table + }""", + ) + + self.connect_mock.assert_called_once() + self.execute_native_query_mock.assert_called_once() + self.get_database_catalog_entry_mock.assert_called_once() + self.disconnect_mock.assert_called_once() + + def test_execute_snowflake_insert_query(self): + execute_query_fetch_all( + self.evadb, + """USE test_data_source { + INSERT INTO test_table ( + name, age, comment + ) VALUES ( + 'val', 5, 'testing' + ) + }""", + ) + self.connect_mock.assert_called_once() + self.execute_native_query_mock.assert_called_once() + self.get_database_catalog_entry_mock.assert_called_once() + self.disconnect_mock.assert_called_once() + + def test_execute_snowflake_update_query(self): + execute_query_fetch_all( + self.evadb, + """USE test_data_source { + UPDATE test_table + SET comment = 'update' + WHERE age > 5 + }""", + ) + + self.connect_mock.assert_called_once() + self.execute_native_query_mock.assert_called_once() + self.get_database_catalog_entry_mock.assert_called_once() + self.disconnect_mock.assert_called_once() + + def test_execute_snowflake_delete_query(self): + execute_query_fetch_all( + self.evadb, + """USE test_data_source { + DELETE FROM test_table + WHERE age < 5 + }""", + ) + + self.connect_mock.assert_called_once() + self.execute_native_query_mock.assert_called_once() + self.get_database_catalog_entry_mock.assert_called_once() + self.disconnect_mock.assert_called_once()