From 6c6e06fa841cef742d49f696ff7e610d1fb99c27 Mon Sep 17 00:00:00 2001 From: auxten Date: Thu, 11 May 2023 02:30:57 +0000 Subject: [PATCH] Add test for issue #31 --- tests/run_all.py | 2 +- tests/test_issue31.py | 43 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/tests/run_all.py b/tests/run_all.py index 6baabb6d7cb..66e64e2a6e0 100755 --- a/tests/run_all.py +++ b/tests/run_all.py @@ -5,7 +5,7 @@ test_loader = unittest.TestLoader() test_suite = test_loader.discover('./') -test_runner = unittest.TextTestRunner() +test_runner = unittest.TextTestRunner(verbosity=2) ret = test_runner.run(test_suite) # if any test fails, exit with non-zero code diff --git a/tests/test_issue31.py b/tests/test_issue31.py index e83d5d931a8..1cd15e04712 100644 --- a/tests/test_issue31.py +++ b/tests/test_issue31.py @@ -2,6 +2,7 @@ import os import time +import hashlib import unittest import chdb import zipfile @@ -24,18 +25,50 @@ def download_and_extract(url, save_path): print("Done!") -@timeout(20, use_signals=False) +# @timeout(60, use_signals=False) + +import signal + + def payload(): now = time.time() res = chdb.query( 'select Name, count(*) cnt from file("organizations-2000000.csv", CSVWithNames) group by Name order by cnt desc', "CSV", ) - print(res.get_memview().tobytes().decode("utf-8")) + # calculate md5 of the result + hash_out = hashlib.md5(res.get_memview().tobytes()).hexdigest() + print("output length: ", len(res.get_memview().tobytes())) + if hash_out != "60833f6ba30f2892f1fda976b2088570": + print(res.get_memview().tobytes().decode("utf-8")) + raise Exception(f"md5 not match {hash_out}") used_time = time.time() - now print("used time: ", used_time) +class TimeoutTestRunner(unittest.TextTestRunner): + def __init__(self, timeout=60, *args, **kwargs): + super().__init__(*args, **kwargs) + self.timeout = timeout + + def run(self, test): + class TimeoutException(Exception): + pass + + def handler(signum, frame): + print("Timeout after {} seconds".format(self.timeout)) + raise TimeoutException("Timeout after {} seconds".format(self.timeout)) + + old_handler = signal.signal(signal.SIGALRM, handler) + signal.alarm(self.timeout) + + result = super().run(test) + + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + return result + + class TestAggOnCSVSpeed(unittest.TestCase): def setUp(self): download_and_extract(csv_url, "organizations-2000000.zip") @@ -44,9 +77,13 @@ def tearDown(self): os.remove("organizations-2000000.csv") os.remove("organizations-2000000.zip") - def test_agg(self): + def _test_agg(self, arg=None): payload() + def test_agg(self): + result = TimeoutTestRunner(timeout=20).run(self._test_agg) + self.assertTrue(result.wasSuccessful(), "Test failed: took too long to execute") + if __name__ == "__main__": unittest.main()