Skip to content

Commit

Permalink
Merge pull request #4 from slidoapp/gkaretka/update-tests-to-pytest
Browse files Browse the repository at this point in the history
Move tests from unstructured to pytest
  • Loading branch information
gkaretka authored Dec 21, 2023
2 parents 1d1c300 + 04381f8 commit 47cf232
Show file tree
Hide file tree
Showing 10 changed files with 181 additions and 123 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ dependencies = [
"pytest",
]
[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
test = "pytest ./tests"
test-cov = "coverage run -m pytest {args:tests}"
cov-report = [
"- coverage combine",
Expand Down
16 changes: 8 additions & 8 deletions src/duckberg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Module containing services needed for executing queries with Duckdb + Iceberg."""

from typing import Optional

import duckdb
from pyarrow.lib import RecordBatchReader
from pyiceberg.catalog import Catalog, load_catalog, load_rest
from duckberg.exceptions import TableNotInCatalogException
from pyiceberg.expressions import AlwaysTrue

from duckberg.exceptions import TableNotInCatalogException
from duckberg.sqlparser import DuckBergSQLParser
from duckberg.table import DuckBergTable, TableWithAlias

Expand Down Expand Up @@ -70,7 +71,9 @@ def list_partitions(self, table: str):

return t.partitions

def select(self, sql: str, table: str = None, partition_filter: str = None, sql_params: [str] = None) -> RecordBatchReader:
def select(
self, sql: str, table: str = None, partition_filter: str = None, sql_params: [str] = None
) -> RecordBatchReader:
if table is not None and partition_filter is not None:
return self._select_old(sql, table, partition_filter, sql_params)

Expand All @@ -94,12 +97,9 @@ def select(self, sql: str, table: str = None, partition_filter: str = None, sql_
if sql_params is None:
return self.duckdb_connection.execute(sql).fetch_record_batch(self.batch_size_rows)
else:
return (
self.duckdb_connection.execute(sql, parameters=sql_params)
.fetch_record_batch(self.batch_size_rows)
)

def _select_old(self, sql: str, table: str, partition_filter: str, sql_params: [str] = None):
return self.duckdb_connection.execute(sql, parameters=sql_params).fetch_record_batch(self.batch_size_rows)

def _select_old(self, sql: str, table: str, partition_filter: str, sql_params: [str] = None):
table_data_scan_as_arrow = self.tables[table].scan(row_filter=partition_filter).to_arrow()
self.duckdb_connection.register(table, table_data_scan_as_arrow)

Expand Down
7 changes: 4 additions & 3 deletions src/duckberg/sqlparser.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import sqlparse

from duckberg.table import TableWithAlias
from pyiceberg.expressions import *
from pyiceberg.expressions import parser

from duckberg.table import TableWithAlias


class DuckBergSQLParser:
def parse_first_query(self, sql: str) -> sqlparse.sql.Statement:
Expand Down Expand Up @@ -46,4 +46,5 @@ def extract_tables(self, parsed_sql: sqlparse.sql.Statement) -> list[TableWithAl

def extract_where_conditions(self, where_statement: list[sqlparse.sql.Where]):
comparison = sqlparse.sql.TokenList(where_statement[1:])
return parser.parse(str(comparison))
where_condition = str(comparison).replace('"', "'") # revert from double to single
return parser.parse(where_condition)
4 changes: 2 additions & 2 deletions src/duckberg/table.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import sqlparse
from pyiceberg.catalog import Catalog
from pyiceberg.expressions import BooleanExpression
from pyiceberg.io import FileIO
from pyiceberg.table import Table
from pyiceberg.table.metadata import TableMetadata
from pyiceberg.typedef import Identifier
from pyiceberg.expressions import BooleanExpression
import sqlparse


class DuckBergTable(Table):
Expand Down
37 changes: 0 additions & 37 deletions tests/duckberg-sample.py

This file was deleted.

40 changes: 0 additions & 40 deletions tests/sqlparser/basic_selects.py

This file was deleted.

58 changes: 58 additions & 0 deletions tests/sqlparser/test_selects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest

from duckberg.sqlparser import DuckBergSQLParser


@pytest.fixture
def get_parser() -> DuckBergSQLParser:
    """Build a new ``DuckBergSQLParser`` for the requesting test."""
    sql_parser = DuckBergSQLParser()
    return sql_parser


def test_basic_select_1(get_parser):
    """A plain, unquoted table name in FROM is extracted as a single table."""
    statement = get_parser.parse_first_query(sql="""SELECT * FROM this_is_awesome_table""")
    tables = get_parser.extract_tables(statement)

    assert len(tables) == 1
    assert [str(table) for table in tables] == ["this_is_awesome_table (None)"]


def test_basic_select_2(get_parser):
    """A single-quoted table name in FROM is extracted with the quotes removed."""
    sql = """SELECT * FROM 'this_is_awesome_table'"""
    sql_parsed = get_parser.parse_first_query(sql=sql)
    res = get_parser.extract_tables(sql_parsed)

    # The leftover debug prints were dropped: pytest already reports the
    # compared values when an assertion fails.
    assert len(res) == 1
    assert [str(table) for table in res] == ["this_is_awesome_table (None)"]


def test_basic_select_3(get_parser):
    """Two comma-separated tables in FROM are both extracted, in order."""
    statement = get_parser.parse_first_query(
        sql="""SELECT * FROM this_is_awesome_table, second_awesome_table"""
    )
    tables = get_parser.extract_tables(statement)

    assert len(tables) == 2
    table_names = [str(table) for table in tables]
    assert table_names == ["this_is_awesome_table (None)", "second_awesome_table (None)"]


def test_basic_select_4(get_parser):
    """A table nested two sub-selects deep is still found."""
    statement = get_parser.parse_first_query(
        sql="""SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table))"""
    )
    tables = get_parser.extract_tables(statement)

    assert len(tables) == 1
    assert [str(table) for table in tables] == ["this_is_awesome_table (None)"]


def test_basic_select_5(get_parser):
    """Tables at different sub-select nesting levels are all collected."""
    statement = get_parser.parse_first_query(
        sql="""SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table), second_awesome_table)"""
    )
    tables = get_parser.extract_tables(statement)

    assert len(tables) == 2
    table_names = [str(table) for table in tables]
    assert table_names == ["this_is_awesome_table (None)", "second_awesome_table (None)"]


def test_basic_select_6(get_parser):
    """An alias following a table name shows up in the table's string form."""
    statement = get_parser.parse_first_query(
        sql="""SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table tiat, second_awesome_table))"""
    )
    tables = get_parser.extract_tables(statement)

    assert len(tables) == 2
    table_names = [str(table) for table in tables]
    assert table_names == ["this_is_awesome_table (tiat)", "second_awesome_table (None)"]
60 changes: 60 additions & 0 deletions tests/sqlparser/test_where.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import pytest

from duckberg.sqlparser import DuckBergSQLParser


@pytest.fixture
def get_parser():
    """Build a new ``DuckBergSQLParser`` for the requesting test."""
    sql_parser = DuckBergSQLParser()
    return sql_parser


def test_select_where_1(get_parser):
    """A single ``>`` comparison in WHERE is attached to the extracted table."""
    statement = get_parser.parse_first_query(sql="""SELECT * FROM this_is_awesome_table WHERE a > 15""")
    tables = get_parser.extract_tables(statement)

    expected = "GreaterThan(term=Reference(name='a'), literal=LongLiteral(15))"
    assert str(tables[0].comparisons) == expected


def test_select_where_2(get_parser):
    """Two comparisons joined by AND are parsed into an ``And`` expression."""
    statement = get_parser.parse_first_query(
        sql="""SELECT * FROM this_is_awesome_table WHERE a > 15 AND a < 16"""
    )
    tables = get_parser.extract_tables(statement)

    expected = "And(left=GreaterThan(term=Reference(name='a'), literal=LongLiteral(15)), right=LessThan(term=Reference(name='a'), literal=LongLiteral(16)))"
    assert str(tables[0].comparisons) == expected


def test_select_where_3(get_parser):
    """A parenthesised AND group combined with OR keeps its grouping."""
    statement = get_parser.parse_first_query(
        sql="""SELECT * FROM this_is_awesome_table WHERE (a > 15 AND a < 16) OR c > 15"""
    )
    tables = get_parser.extract_tables(statement)

    expected = "Or(left=And(left=GreaterThan(term=Reference(name='a'), literal=LongLiteral(15)), right=LessThan(term=Reference(name='a'), literal=LongLiteral(16))), right=GreaterThan(term=Reference(name='c'), literal=LongLiteral(15)))"
    assert str(tables[0].comparisons) == expected


def test_select_where_4(get_parser):
    """A double-quoted string literal in WHERE is parsed as a string literal."""
    statement = get_parser.parse_first_query(
        sql="""SELECT * FROM this_is_awesome_table WHERE (b = "test string" AND a < 16) OR c > 15"""
    )
    tables = get_parser.extract_tables(statement)

    expected = "Or(left=And(left=EqualTo(term=Reference(name='b'), literal=literal('test string')), right=LessThan(term=Reference(name='a'), literal=LongLiteral(16))), right=GreaterThan(term=Reference(name='c'), literal=LongLiteral(15)))"
    assert str(tables[0].comparisons) == expected


def test_select_where_5(get_parser):
    """Double- and single-quoted string literals can be mixed in one WHERE clause.

    Renamed from a second ``test_select_where_4``: a later ``def`` with the
    same name rebinds the module attribute, so pytest collected only this
    test and silently skipped the earlier one.
    """
    sql = """SELECT * FROM this_is_awesome_table WHERE (b = "test string" AND column = '108e6307-f23a-4e10-9e38-1866d58b4355') OR c > 15"""
    sql_parsed = get_parser.parse_first_query(sql=sql)
    res = get_parser.extract_tables(sql_parsed)
    res_where = str(res[0].comparisons)
    assert (
        "Or(left=And(left=EqualTo(term=Reference(name='b'), literal=literal('test string')), right=EqualTo(term=Reference(name='column'), literal=literal('108e6307-f23a-4e10-9e38-1866d58b4355'))), right=GreaterThan(term=Reference(name='c'), literal=LongLiteral(15)))"
        == res_where
    )
32 changes: 0 additions & 32 deletions tests/sqlparser/where_selects.py

This file was deleted.

48 changes: 48 additions & 0 deletions tests/test_duckberg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest

from duckberg import DuckBerg


@pytest.fixture
def get_duckberg() -> DuckBerg:
    """Build a ``DuckBerg`` instance configured for locally running services.

    The configuration points at a REST catalog on ``localhost:8181`` and an
    S3-compatible endpoint on ``localhost:9000``; both must be reachable for
    the tests that use this fixture.
    """
    s3_endpoint = "http://localhost:9000/"
    s3_user = "admin"
    s3_password = "password"

    rest_catalog_settings: dict[str, str] = {
        "type": "rest",
        "uri": "http://localhost:8181/",
        "credentials": "admin:password",
        "s3.endpoint": s3_endpoint,
        "s3.access-key-id": s3_user,
        "s3.secret-access-key": s3_password,
    }

    return DuckBerg(catalog_name="warehouse", catalog_config=rest_catalog_settings)


def test_list_tables(get_duckberg):
    """The catalog under test is expected to expose exactly one table."""
    assert len(get_duckberg.list_tables()) == 1


def test_select_1(get_duckberg):
    """Count rows through ``select`` when only the SQL text is supplied."""
    # Newer way of querying data: no table / partition_filter arguments.
    sql_text: str = "SELECT count(*) FROM (SELECT * FROM 'nyc.taxis' WHERE trip_distance > 40 ORDER BY tolls_amount DESC)"
    result_frame = get_duckberg.select(sql=sql_text).read_pandas()
    assert result_frame["count_star()"][0] == 2614


def test_select_2(get_duckberg):
    """Count rows through ``select`` with two AND-ed conditions in the SQL."""
    # Newer way of querying data: the filter lives in the SQL text itself.
    sql_text: str = "SELECT count(*) FROM (SELECT * FROM 'nyc.taxis' WHERE payment_type = 1 AND trip_distance > 40 ORDER BY tolls_amount DESC)"
    result_frame = get_duckberg.select(sql=sql_text).read_pandas()
    assert result_frame["count_star()"][0] == 1673


def test_select_3(get_duckberg):
    """Count rows through ``select`` with an explicit table and partition filter."""
    # Older way of querying data: table and partition_filter are passed
    # alongside the SQL text.
    sql_text: str = "SELECT count(*) FROM (SELECT * FROM 'nyc.taxis' WHERE payment_type = 1 AND trip_distance > 40 ORDER BY tolls_amount DESC)"
    record_reader = get_duckberg.select(sql=sql_text, table="nyc.taxis", partition_filter="payment_type = 1")
    result_frame = record_reader.read_pandas()
    assert result_frame["count_star()"][0] == 1673

0 comments on commit 47cf232

Please sign in to comment.