Skip to content

Commit

Permalink
Add test for issue #31
Browse files Browse the repository at this point in the history
  • Loading branch information
auxten committed Jun 7, 2024
1 parent f9c7fbc commit 6c6e06f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tests/run_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
test_loader = unittest.TestLoader()
test_suite = test_loader.discover('./')

test_runner = unittest.TextTestRunner()
test_runner = unittest.TextTestRunner(verbosity=2)
ret = test_runner.run(test_suite)

# if any test fails, exit with non-zero code
Expand Down
43 changes: 40 additions & 3 deletions tests/test_issue31.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import time
import hashlib
import unittest
import chdb
import zipfile
Expand All @@ -24,18 +25,50 @@ def download_and_extract(url, save_path):
print("Done!")


@timeout(20, use_signals=False)
# @timeout(60, use_signals=False)

import signal


def payload():
now = time.time()
res = chdb.query(
'select Name, count(*) cnt from file("organizations-2000000.csv", CSVWithNames) group by Name order by cnt desc',
"CSV",
)
print(res.get_memview().tobytes().decode("utf-8"))
# calculate md5 of the result
hash_out = hashlib.md5(res.get_memview().tobytes()).hexdigest()
print("output length: ", len(res.get_memview().tobytes()))
if hash_out != "60833f6ba30f2892f1fda976b2088570":
print(res.get_memview().tobytes().decode("utf-8"))
raise Exception(f"md5 not match {hash_out}")
used_time = time.time() - now
print("used time: ", used_time)


class TimeoutTestRunner(unittest.TextTestRunner):
def __init__(self, timeout=60, *args, **kwargs):
super().__init__(*args, **kwargs)
self.timeout = timeout

def run(self, test):
class TimeoutException(Exception):
pass

def handler(signum, frame):
print("Timeout after {} seconds".format(self.timeout))
raise TimeoutException("Timeout after {} seconds".format(self.timeout))

old_handler = signal.signal(signal.SIGALRM, handler)
signal.alarm(self.timeout)

result = super().run(test)

signal.alarm(0)
signal.signal(signal.SIGALRM, old_handler)
return result


class TestAggOnCSVSpeed(unittest.TestCase):
def setUp(self):
download_and_extract(csv_url, "organizations-2000000.zip")
Expand All @@ -44,9 +77,13 @@ def tearDown(self):
os.remove("organizations-2000000.csv")
os.remove("organizations-2000000.zip")

def test_agg(self):
def _test_agg(self, arg=None):
payload()

def test_agg(self):
result = TimeoutTestRunner(timeout=20).run(self._test_agg)
self.assertTrue(result.wasSuccessful(), "Test failed: took too long to execute")


if __name__ == "__main__":
unittest.main()

0 comments on commit 6c6e06f

Please sign in to comment.