Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SnowFlake Integration for EvaDB #1289

Merged
merged 10 commits into from
Oct 20, 2023
2 changes: 2 additions & 0 deletions evadb/third_party/databases/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def _get_database_handler(engine: str, **kwargs):
return mod.MariaDbHandler(engine, **kwargs)
elif engine == "clickhouse":
return mod.ClickHouseHandler(engine, **kwargs)
elif engine == "snowflake":
return mod.SnowFlakeDbHandler(engine, **kwargs)
elif engine == "github":
return mod.GithubHandler(engine, **kwargs)
else:
Expand Down
15 changes: 15 additions & 0 deletions evadb/third_party/databases/snowflake/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""snowflake integrations"""
3 changes: 3 additions & 0 deletions evadb/third_party/databases/snowflake/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
snowflake-connector-python
pyarrow
pandas
184 changes: 184 additions & 0 deletions evadb/third_party/databases/snowflake/snowflake_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime

import pandas as pd
import snowflake.connector

from evadb.third_party.databases.types import (
DBHandler,
DBHandlerResponse,
DBHandlerStatus,
)


class SnowFlakeDbHandler(DBHandler):

"""
Class for implementing the SnowFlake DB handler as a backend store for
EvaDB.
"""

def __init__(self, name: str, **kwargs):
"""
Initialize the handler.
Args:
name (str): name of the DB handler instance
**kwargs: arbitrary keyword arguments for establishing the connection.
"""
super().__init__(name)
self.user = kwargs.get("user")
self.password = kwargs.get("password")
self.database = kwargs.get("database")
self.warehouse = kwargs.get("warehouse")
self.account = kwargs.get("account")
self.schema = kwargs.get("schema")

def connect(self):
"""
Establish connection to the database.
Returns:
DBHandlerStatus
"""
try:
self.connection = snowflake.connector.connect(
user=self.user,
password=self.password,
database=self.database,
warehouse=self.warehouse,
schema=self.schema,
account=self.account,
)
# Auto commit is off by default.
self.connection.autocommit = True
return DBHandlerStatus(status=True)
except snowflake.connector.errors.Error as e:
return DBHandlerStatus(status=False, error=str(e))

def disconnect(self):
"""
Disconnect from the database.
"""
if self.connection:
self.connection.close()

def check_connection(self) -> DBHandlerStatus:
"""
Method for checking the status of database connection.
Returns:
DBHandlerStatus
"""
if self.connection:
return DBHandlerStatus(status=True)
else:
return DBHandlerStatus(status=False, error="Not connected to the database.")

def get_tables(self) -> DBHandlerResponse:
"""
Method to get the list of tables from database.
Returns:
DBHandlerStatus
"""
if not self.connection:
return DBHandlerResponse(data=None, error="Not connected to the database.")

try:
query = f"SELECT table_name as table_name FROM information_schema.tables WHERE table_schema='{self.schema}'"
cursor = self.connection.cursor()
cursor.execute(query)
tables_df = self._fetch_results_as_df(cursor)
return DBHandlerResponse(data=tables_df)
except snowflake.connector.errors.Error as e:
return DBHandlerResponse(data=None, error=str(e))

def get_columns(self, table_name: str) -> DBHandlerResponse:
"""
Method to retrieve the columns of the specified table from the database.
Args:
table_name (str): name of the table whose columns are to be retrieved.
Returns:
DBHandlerStatus
"""
if not self.connection:
return DBHandlerResponse(data=None, error="Not connected to the database.")

try:
query = f"SELECT column_name as name, data_type as dtype FROM information_schema.columns WHERE table_name='{table_name}'"
cursor = self.connection.cursor()
cursor.execute(query)
columns_df = self._fetch_results_as_df(cursor)
columns_df["dtype"] = columns_df["dtype"].apply(
self._snowflake_to_python_types
)
return DBHandlerResponse(data=columns_df)
except snowflake.connector.errors.Error as e:
return DBHandlerResponse(data=None, error=str(e))

def _fetch_results_as_df(self, cursor):
"""
Fetch results from the cursor for the executed query and return the
query results as dataframe.
"""
try:
res = cursor.fetchall()
res_df = pd.DataFrame(
res, columns=[desc[0].lower() for desc in cursor.description]
)
return res_df
except snowflake.connector.errors.ProgrammingError as e:
if str(e) == "no results to fetch":
return pd.DataFrame({"status": ["success"]})
raise e

def execute_native_query(self, query_string: str) -> DBHandlerResponse:
"""
Executes the native query on the database.
Args:
query_string (str): query in native format
Returns:
DBHandlerResponse
"""
if not self.connection:
return DBHandlerResponse(data=None, error="Not connected to the database.")

try:
cursor = self.connection.cursor()
cursor.execute(query_string)
return DBHandlerResponse(data=self._fetch_results_as_df(cursor))
except snowflake.connector.errors.Error as e:
return DBHandlerResponse(data=None, error=str(e))

def _snowflake_to_python_types(self, snowflake_type: str):
mapping = {
"TEXT": str,
"NUMBER": int,
"INT": int,
"DECIMAL": float,
"STRING": str,
"CHAR": str,
"BOOLEAN": bool,
"BINARY": bytes,
"DATE": datetime.date,
"TIME": datetime.time,
"TIMESTAMP": datetime.datetime
# Add more mappings as needed
}

if snowflake_type in mapping:
return mapping[snowflake_type]
else:
raise Exception(
f"Unsupported column {snowflake_type} encountered in the snowflake. Please raise a feature request!"
)
20 changes: 20 additions & 0 deletions test/third_party_tests/test_native_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,26 @@ def test_should_run_query_in_clickhouse(self):
self._execute_native_query()
self._execute_evadb_query()

def test_should_run_query_in_snowflake(self):
# Create database.
params = {
"user": "eva",
"password": "password",
"account": "account_number",
"database": "EVADB",
"schema": "SAMPLE_DATA",
"warehouse": "warehouse",
}
xzdandy marked this conversation as resolved.
Show resolved Hide resolved
query = f"""CREATE DATABASE test_data_source
WITH ENGINE = "snowflake",
PARAMETERS = {params};"""
execute_query_fetch_all(self.evadb, query)

# Test executions.
self._execute_native_query()
self._execute_evadb_query()


def test_should_run_query_in_sqlite(self):
# Create database.
import os
Expand Down
144 changes: 144 additions & 0 deletions test/unit_tests/storage/test_snowflake_native_storage_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest
from test.util import get_evadb_for_testing
from unittest.mock import MagicMock, patch

import pytest

from evadb.catalog.models.utils import DatabaseCatalogEntry
from evadb.server.command_handler import execute_query_fetch_all

sys.modules["snowflake"] = MagicMock()
sys.modules["snowflake.connector"] = MagicMock()


class NativeQueryResponse:
def __init__(self):
self.error = None
self.data = None


@pytest.mark.notparallel
class SnowFlakeNativeStorageEngineTest(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def get_snowflake_params(self):
return {
"database": "evadb.db",
}

def setUp(self):
connection_params = self.get_snowflake_params()
self.evadb = get_evadb_for_testing()

# Create all class level patches
self.get_database_catalog_entry_patcher = patch(
"evadb.catalog.catalog_manager.CatalogManager.get_database_catalog_entry"
)
self.get_database_catalog_entry_mock = (
self.get_database_catalog_entry_patcher.start()
)

self.execute_native_query_patcher = patch(
"evadb.third_party.databases.snowflake.snowflake_handler.SnowFlakeDbHandler.execute_native_query"
)
self.execute_native_query_mock = self.execute_native_query_patcher.start()

self.connect_patcher = patch(
"evadb.third_party.databases.snowflake.snowflake_handler.SnowFlakeDbHandler.connect"
)
self.connect_mock = self.connect_patcher.start()

self.disconnect_patcher = patch(
"evadb.third_party.databases.snowflake.snowflake_handler.SnowFlakeDbHandler.disconnect"
)
self.disconnect_mock = self.disconnect_patcher.start()

# set return values
self.execute_native_query_mock.return_value = NativeQueryResponse()
self.get_database_catalog_entry_mock.return_value = DatabaseCatalogEntry(
name="test_data_source",
engine="snowflake",
params=connection_params,
row_id=1,
)

def tearDown(self):
self.get_database_catalog_entry_patcher.stop()
self.execute_native_query_patcher.stop()
self.connect_patcher.stop()
self.disconnect_patcher.stop()

def test_execute_snowflake_select_query(self):
execute_query_fetch_all(
self.evadb,
"""USE test_data_source {
SELECT * FROM test_table
}""",
)

self.connect_mock.assert_called_once()
self.execute_native_query_mock.assert_called_once()
self.get_database_catalog_entry_mock.assert_called_once()
self.disconnect_mock.assert_called_once()

def test_execute_snowflake_insert_query(self):
execute_query_fetch_all(
self.evadb,
"""USE test_data_source {
INSERT INTO test_table (
name, age, comment
) VALUES (
'val', 5, 'testing'
)
}""",
)
self.connect_mock.assert_called_once()
self.execute_native_query_mock.assert_called_once()
self.get_database_catalog_entry_mock.assert_called_once()
self.disconnect_mock.assert_called_once()

def test_execute_snowflake_update_query(self):
execute_query_fetch_all(
self.evadb,
"""USE test_data_source {
UPDATE test_table
SET comment = 'update'
WHERE age > 5
}""",
)

self.connect_mock.assert_called_once()
self.execute_native_query_mock.assert_called_once()
self.get_database_catalog_entry_mock.assert_called_once()
self.disconnect_mock.assert_called_once()

def test_execute_snowflake_delete_query(self):
execute_query_fetch_all(
self.evadb,
"""USE test_data_source {
DELETE FROM test_table
WHERE age < 5
}""",
)

self.connect_mock.assert_called_once()
self.execute_native_query_mock.assert_called_once()
self.get_database_catalog_entry_mock.assert_called_once()
self.disconnect_mock.assert_called_once()