From 9c6675723c1692a9fdbbd76ec1681f811f99e0e1 Mon Sep 17 00:00:00 2001 From: Zhao Liang Date: Fri, 25 Nov 2022 08:46:55 +0800 Subject: [PATCH] [test] Add jupyter notebook to tests (#6717) --- requirements_test.txt | 1 + tests/python/test_ipython.ipynb | 68 +++++++++++++++++++++++++++++++++ tests/run_tests.py | 3 +- 3 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 tests/python/test_ipython.ipynb diff --git a/requirements_test.txt b/requirements_test.txt index a5ca479fdc82f..375222501501c 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -11,3 +11,4 @@ matplotlib cffi scipy setproctitle +nbmake diff --git a/tests/python/test_ipython.ipynb b/tests/python/test_ipython.ipynb new file mode 100644 index 0000000000000..a895423861032 --- /dev/null +++ b/tests/python/test_ipython.ipynb @@ -0,0 +1,68 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "51b3b00e", + "metadata": {}, + "source": [ + "# Count primes below a given bound" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "83062572", + "metadata": {}, + "outputs": [], + "source": [ + "import taichi as ti\n", + "\n", + "def test_kernel_print():\n", + " N = 10000\n", + " ti.init()\n", + " @ti.func\n", + " def is_prime(n: int):\n", + " result = True\n", + " for k in range(2, int(n**0.5) + 1):\n", + " if n % k == 0:\n", + " result = False\n", + " break\n", + " return result\n", + "\n", + " @ti.kernel\n", + " def count_primes(n: int) -> int:\n", + " count = 0\n", + " for k in range(2, n):\n", + " if is_prime(k):\n", + " count += 1\n", + "\n", + " return count\n", + "\n", + " count_primes(N)\n", + " \n", + "test_kernel_print()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/run_tests.py b/tests/run_tests.py index 25518563970e7..00c3f4599c100 100644 --- a/tests/run_tests.py +++ b/tests/run_tests.py @@ -153,7 +153,7 @@ def _test_python(args): # auto-complete file names if not f.startswith('test_'): f = 'test_' + f - if not f.endswith('.py'): + if not (f.endswith('.py') or f.endswith('.ipynb')): f = f + '.py' file = os.path.join(test_dir, f) has_tests = False @@ -164,6 +164,7 @@ def _test_python(args): else: # run all the tests pytest_args = [test_dir] + pytest_args += ['--nbmake'] if args.verbose: pytest_args += ['-v'] if args.rerun: