diff --git a/docs/_toc.yml b/docs/_toc.yml
index 6a37f98261..6739643319 100644
--- a/docs/_toc.yml
+++ b/docs/_toc.yml
@@ -71,6 +71,7 @@ parts:
- file: source/reference/databases/mariadb
- file: source/reference/databases/clickhouse
- file: source/reference/databases/github
+ - file: source/reference/databases/snowflake
- file: source/reference/vector_databases/index
title: Vector Databases
diff --git a/docs/source/reference/databases/snowflake.rst b/docs/source/reference/databases/snowflake.rst
new file mode 100644
index 0000000000..239389e186
--- /dev/null
+++ b/docs/source/reference/databases/snowflake.rst
@@ -0,0 +1,47 @@
+Snowflake
+==========
+
+The connection to Snowflake is based on the `snowflake-connector-python <https://pypi.org/project/snowflake-connector-python/>`_ library.
+
+Dependency
+----------
+
+* snowflake-connector-python
+
+Parameters
+----------
+
+Required:
+
+* `user` is the database user.
+* `password` is the snowflake account password.
+* `database` is the database name.
+* `warehouse` is the snowflake warehouse name.
+* `account` is the snowflake account number (it can be found in the account URL).
+* `schema` is the schema name.
+
+
+.. warning::
+
+ Provide the parameters of an already running ``Snowflake`` Data Warehouse. EvaDB only connects to an existing ``Snowflake`` Data Warehouse.
+
+Create Connection
+-----------------
+
+.. code-block:: text
+
+ CREATE DATABASE snowflake_data WITH ENGINE = 'snowflake', PARAMETERS = {
+ "user": "",
+ "password": "",
+ "account": "",
+ "database": "EVADB",
+ "warehouse": "COMPUTE_WH",
+ "schema": "SAMPLE_DATA"
+ };
+
+.. warning::
+
+ | In Snowflake Terminology, ``Database`` and ``Schema`` refer to the following.
+ | A database is a logical grouping of schemas. Each database belongs to a single Snowflake account.
+ | A schema is a logical grouping of database objects (tables, views, etc.). Each schema belongs to a single database.
+
diff --git a/evadb/third_party/databases/interface.py b/evadb/third_party/databases/interface.py
index 5e30dc8220..e4cd86151c 100644
--- a/evadb/third_party/databases/interface.py
+++ b/evadb/third_party/databases/interface.py
@@ -44,6 +44,8 @@ def _get_database_handler(engine: str, **kwargs):
return mod.MariaDbHandler(engine, **kwargs)
elif engine == "clickhouse":
return mod.ClickHouseHandler(engine, **kwargs)
+ elif engine == "snowflake":
+ return mod.SnowFlakeDbHandler(engine, **kwargs)
elif engine == "github":
return mod.GithubHandler(engine, **kwargs)
else:
diff --git a/evadb/third_party/databases/snowflake/__init__.py b/evadb/third_party/databases/snowflake/__init__.py
new file mode 100644
index 0000000000..c881047fe1
--- /dev/null
+++ b/evadb/third_party/databases/snowflake/__init__.py
@@ -0,0 +1,15 @@
+# coding=utf-8
+# Copyright 2018-2023 EvaDB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""snowflake integrations"""
diff --git a/evadb/third_party/databases/snowflake/requirements.txt b/evadb/third_party/databases/snowflake/requirements.txt
new file mode 100644
index 0000000000..c366baf101
--- /dev/null
+++ b/evadb/third_party/databases/snowflake/requirements.txt
@@ -0,0 +1,3 @@
+snowflake-connector-python
+pyarrow
+pandas
\ No newline at end of file
diff --git a/evadb/third_party/databases/snowflake/snowflake_handler.py b/evadb/third_party/databases/snowflake/snowflake_handler.py
new file mode 100644
index 0000000000..0b5ba4553d
--- /dev/null
+++ b/evadb/third_party/databases/snowflake/snowflake_handler.py
@@ -0,0 +1,216 @@
+# coding=utf-8
+# Copyright 2018-2023 EvaDB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import datetime
+
+import pandas as pd
+import snowflake.connector
+
+from evadb.third_party.databases.types import (
+    DBHandler,
+    DBHandlerResponse,
+    DBHandlerStatus,
+)
+
+
+class SnowFlakeDbHandler(DBHandler):
+    """
+    Class for implementing the SnowFlake DB handler as a backend store for
+    EvaDB.
+    """
+
+    _NOT_CONNECTED = "Not connected to the database."
+
+    # information_schema data_type name -> Python type. Kept on the class so
+    # the table is built once rather than on every column lookup.
+    _TYPE_MAPPING = {
+        "TEXT": str,
+        "STRING": str,
+        "CHAR": str,
+        "NUMBER": int,
+        "INT": int,
+        "DECIMAL": float,
+        "FLOAT": float,
+        "DOUBLE": float,
+        "REAL": float,
+        "BOOLEAN": bool,
+        "BINARY": bytes,
+        "DATE": datetime.date,
+        "TIME": datetime.time,
+        "TIMESTAMP": datetime.datetime,
+        "TIMESTAMP_NTZ": datetime.datetime,
+        "TIMESTAMP_LTZ": datetime.datetime,
+        "TIMESTAMP_TZ": datetime.datetime,
+        # Add more mappings as needed
+    }
+
+    def __init__(self, name: str, **kwargs):
+        """
+        Initialize the handler.
+        Args:
+            name (str): name of the DB handler instance
+            **kwargs: arbitrary keyword arguments for establishing the connection.
+        """
+        super().__init__(name)
+        self.user = kwargs.get("user")
+        self.password = kwargs.get("password")
+        self.database = kwargs.get("database")
+        self.warehouse = kwargs.get("warehouse")
+        self.account = kwargs.get("account")
+        self.schema = kwargs.get("schema")
+        # Defined up front so every method can test it before connect() ran.
+        self.connection = None
+
+    def connect(self):
+        """
+        Establish connection to the database.
+        Returns:
+            DBHandlerStatus
+        """
+        try:
+            # Autocommit is requested through connect(): on the connection
+            # object `autocommit` is a method, so assigning to it has no effect.
+            self.connection = snowflake.connector.connect(
+                user=self.user,
+                password=self.password,
+                database=self.database,
+                warehouse=self.warehouse,
+                schema=self.schema,
+                account=self.account,
+                autocommit=True,
+            )
+            return DBHandlerStatus(status=True)
+        except snowflake.connector.errors.Error as e:
+            return DBHandlerStatus(status=False, error=str(e))
+
+    def disconnect(self):
+        """
+        Disconnect from the database.
+        """
+        if self.connection:
+            self.connection.close()
+            # Drop the closed handle so check_connection() stops reporting it.
+            self.connection = None
+
+    def check_connection(self) -> DBHandlerStatus:
+        """
+        Method for checking the status of database connection.
+        Returns:
+            DBHandlerStatus
+        """
+        if self.connection:
+            return DBHandlerStatus(status=True)
+        return DBHandlerStatus(status=False, error=self._NOT_CONNECTED)
+
+    def get_tables(self) -> DBHandlerResponse:
+        """
+        Method to get the list of tables from database.
+        Returns:
+            DBHandlerResponse
+        """
+        if not self.connection:
+            return DBHandlerResponse(data=None, error=self._NOT_CONNECTED)
+
+        # The schema is bound as a parameter rather than formatted into the SQL.
+        query = (
+            "SELECT table_name as table_name FROM information_schema.tables "
+            "WHERE table_schema=%s"
+        )
+        try:
+            with self.connection.cursor() as cursor:
+                cursor.execute(query, (self.schema,))
+                return DBHandlerResponse(data=self._fetch_results_as_df(cursor))
+        except snowflake.connector.errors.Error as e:
+            return DBHandlerResponse(data=None, error=str(e))
+
+    def get_columns(self, table_name: str) -> DBHandlerResponse:
+        """
+        Method to retrieve the columns of the specified table from the database.
+        Args:
+            table_name (str): name of the table whose columns are to be retrieved.
+        Returns:
+            DBHandlerResponse
+        """
+        if not self.connection:
+            return DBHandlerResponse(data=None, error=self._NOT_CONNECTED)
+
+        # table_name is bound as a parameter rather than formatted into the SQL.
+        # NOTE(review): the lookup is by table name only, so same-named tables
+        # in other schemas may add rows - confirm whether to filter on schema.
+        query = (
+            "SELECT column_name as name, data_type as dtype "
+            "FROM information_schema.columns WHERE table_name=%s"
+        )
+        try:
+            with self.connection.cursor() as cursor:
+                cursor.execute(query, (table_name,))
+                columns_df = self._fetch_results_as_df(cursor)
+            columns_df["dtype"] = columns_df["dtype"].apply(
+                self._snowflake_to_python_types
+            )
+            return DBHandlerResponse(data=columns_df)
+        except snowflake.connector.errors.Error as e:
+            return DBHandlerResponse(data=None, error=str(e))
+
+    def _fetch_results_as_df(self, cursor):
+        """
+        Fetch results from the cursor for the executed query and return the
+        query results as dataframe. A statement without a result set yields a
+        one-row ``status`` frame instead.
+        """
+        if cursor.description is None:
+            # DB-API: no description means the statement produced no rows.
+            return pd.DataFrame({"status": ["success"]})
+        try:
+            res = cursor.fetchall()
+            return pd.DataFrame(
+                res, columns=[desc[0].lower() for desc in cursor.description]
+            )
+        except snowflake.connector.errors.ProgrammingError as e:
+            if str(e) == "no results to fetch":
+                return pd.DataFrame({"status": ["success"]})
+            raise
+
+    def execute_native_query(self, query_string: str) -> DBHandlerResponse:
+        """
+        Executes the native query on the database.
+        Args:
+            query_string (str): query in native format
+        Returns:
+            DBHandlerResponse
+        """
+        if not self.connection:
+            return DBHandlerResponse(data=None, error=self._NOT_CONNECTED)
+
+        try:
+            with self.connection.cursor() as cursor:
+                cursor.execute(query_string)
+                return DBHandlerResponse(data=self._fetch_results_as_df(cursor))
+        except snowflake.connector.errors.Error as e:
+            return DBHandlerResponse(data=None, error=str(e))
+
+    def _snowflake_to_python_types(self, snowflake_type: str):
+        """
+        Map a Snowflake data type name to the matching Python type.
+        Args:
+            snowflake_type (str): type name; matched case-insensitively.
+        Raises:
+            Exception: if the type name has no entry in the mapping.
+        """
+        try:
+            return self._TYPE_MAPPING[str(snowflake_type).upper()]
+        except KeyError:
+            raise Exception(
+                f"Unsupported column {snowflake_type} encountered in the snowflake. Please raise a feature request!"
+            ) from None
diff --git a/test/third_party_tests/test_native_executor.py b/test/third_party_tests/test_native_executor.py
index 7eaf27cb85..879435f866 100644
--- a/test/third_party_tests/test_native_executor.py
+++ b/test/third_party_tests/test_native_executor.py
@@ -228,6 +228,28 @@ def test_should_run_query_in_clickhouse(self):
self._execute_native_query()
self._execute_evadb_query()
+ @pytest.mark.skip(
+ reason="Snowflake does not come with a free version of account, so integration test is not feasible"
+ )
+ def test_should_run_query_in_snowflake(self):
+ # Register a Snowflake data source; the credentials below look like placeholders (the test is skipped).
+ params = {
+ "user": "eva",
+ "password": "password",
+ "account": "account_number",
+ "database": "EVADB",
+ "schema": "SAMPLE_DATA",
+ "warehouse": "warehouse",
+ }
+ query = f"""CREATE DATABASE test_data_source
+ WITH ENGINE = "snowflake",
+ PARAMETERS = {params};"""
+ execute_query_fetch_all(self.evadb, query)
+
+ # Run the same native-query and EvaDB-query helpers that the preceding ClickHouse test calls.
+ self._execute_native_query()
+ self._execute_evadb_query()
+
def test_should_run_query_in_sqlite(self):
# Create database.
import os
diff --git a/test/unit_tests/storage/test_snowflake_native_storage_engine.py b/test/unit_tests/storage/test_snowflake_native_storage_engine.py
new file mode 100644
index 0000000000..50fcea23b3
--- /dev/null
+++ b/test/unit_tests/storage/test_snowflake_native_storage_engine.py
@@ -0,0 +1,144 @@
+# coding=utf-8
+# Copyright 2018-2023 EvaDB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+from test.util import get_evadb_for_testing
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from evadb.catalog.models.utils import DatabaseCatalogEntry
+from evadb.server.command_handler import execute_query_fetch_all
+
+sys.modules["snowflake"] = MagicMock()
+sys.modules["snowflake.connector"] = MagicMock()
+
+
+class NativeQueryResponse:
+    """Minimal response stub carrying neither an error nor data."""
+    def __init__(self):
+        self.error = self.data = None
+
+
+@pytest.mark.notparallel
+class SnowFlakeNativeStorageEngineTest(unittest.TestCase):
+    """
+    Checks that native statements wrapped in ``USE test_data_source { ... }``
+    are routed to the Snowflake handler. The catalog lookup and the handler's
+    connect / execute_native_query / disconnect methods are patched, so the
+    tests only observe how often each of them is called.
+    """
+
+    # Dotted path of the handler class whose methods are patched.
+    HANDLER = (
+        "evadb.third_party.databases.snowflake.snowflake_handler.SnowFlakeDbHandler"
+    )
+
+    def get_snowflake_params(self):
+        """Return the connection parameters stored in the fake catalog entry."""
+        return {
+            "database": "evadb.db",
+        }
+
+    def _start_patch(self, target):
+        """
+        Start a patch on ``target`` and return the mock that replaces it.
+
+        The matching stop is registered with addCleanup, so the patch is undone
+        even when a later setUp step raises (tearDown is not run in that case).
+        """
+        patcher = patch(target)
+        started = patcher.start()
+        self.addCleanup(patcher.stop)
+        return started
+
+    def setUp(self):
+        """Create the EvaDB test instance and install the patches."""
+        connection_params = self.get_snowflake_params()
+        self.evadb = get_evadb_for_testing()
+
+        # Patch the catalog lookup and the handler methods for every test.
+        self.get_database_catalog_entry_mock = self._start_patch(
+            "evadb.catalog.catalog_manager.CatalogManager.get_database_catalog_entry"
+        )
+        self.execute_native_query_mock = self._start_patch(
+            f"{self.HANDLER}.execute_native_query"
+        )
+        self.connect_mock = self._start_patch(f"{self.HANDLER}.connect")
+        self.disconnect_mock = self._start_patch(f"{self.HANDLER}.disconnect")
+
+        # set return values
+        self.execute_native_query_mock.return_value = NativeQueryResponse()
+        self.get_database_catalog_entry_mock.return_value = DatabaseCatalogEntry(
+            name="test_data_source",
+            engine="snowflake",
+            params=connection_params,
+            row_id=1,
+        )
+
+    def _run_and_assert_single_round_trip(self, query):
+        """
+        Execute ``query`` and assert that the catalog lookup, connect,
+        execute_native_query and disconnect were each called exactly once.
+        Args:
+            query (str): EvaDB query wrapping the native statement.
+        """
+        execute_query_fetch_all(self.evadb, query)
+
+        # Every patched entry point must have been hit exactly once.
+        self.connect_mock.assert_called_once()
+        self.execute_native_query_mock.assert_called_once()
+        self.get_database_catalog_entry_mock.assert_called_once()
+        self.disconnect_mock.assert_called_once()
+
+    def test_execute_snowflake_select_query(self):
+        """A SELECT goes through one connect/execute/disconnect cycle."""
+        self._run_and_assert_single_round_trip(
+            """USE test_data_source {
+                SELECT * FROM test_table
+            }"""
+        )
+
+    def test_execute_snowflake_insert_query(self):
+        """An INSERT goes through one connect/execute/disconnect cycle."""
+        self._run_and_assert_single_round_trip(
+            """USE test_data_source {
+                INSERT INTO test_table (
+                    name, age, comment
+                ) VALUES (
+                    'val', 5, 'testing'
+                )
+            }"""
+        )
+
+    def test_execute_snowflake_update_query(self):
+        """An UPDATE goes through one connect/execute/disconnect cycle."""
+        self._run_and_assert_single_round_trip(
+            """USE test_data_source {
+                UPDATE test_table
+                SET comment = 'update'
+                WHERE age > 5
+            }"""
+        )
+
+    def test_execute_snowflake_delete_query(self):
+        """A DELETE goes through one connect/execute/disconnect cycle."""
+        self._run_and_assert_single_round_trip(
+            """USE test_data_source {
+                DELETE FROM test_table
+                WHERE age < 5
+            }"""
+        )