Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use argparse in taichi/main.py to implement #600 #601

Merged
merged 5 commits into from
Mar 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ script:
- export TAICHI_REPO_DIR=$TRAVIS_BUILD_DIR
- export PYTHONPATH=$TAICHI_REPO_DIR/python
- export PATH=$TAICHI_REPO_DIR/bin:$PATH
- ti test_verbose && cd python && $PYTHON build.py try_upload
- ti test -v && cd python && $PYTHON build.py try_upload

env:
global:
Expand Down
15 changes: 13 additions & 2 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,12 @@ def benchmark(func, repeat=100, args=()):
elapsed = time.time() - t
return elapsed / repeat

wanted_archs = None

def set_wanted_archs(archs):
global wanted_archs
wanted_archs = archs

def supported_archs():
import taichi as ti
archs = [ti.core.host_arch()]
Expand All @@ -223,6 +229,11 @@ def supported_archs():
archs.append(metal)
if ti.core.with_opengl():
archs.append(opengl)
if wanted_archs is not None:
archs, old_archs = [], archs
for arch in old_archs:
if ti.core.arch_name(arch) in wanted_archs:
archs.append(arch)
return archs

class _ArchCheckers(object):
Expand Down Expand Up @@ -286,11 +297,11 @@ def all_archs(test):
#
# Example usage:
#
# ti.archs_excluding(ti.cuda, ti.metal)
# @ti.archs_excluding(ti.cuda, ti.metal)
# def test_xx():
# ...
#
# ti.archs_excluding(ti.cuda, default_fp=ti.f64)
# @ti.archs_excluding(ti.cuda, default_fp=ti.f64)
archibate marked this conversation as resolved.
Show resolved Hide resolved
# def test_yy():
# ...
def archs_excluding(*excluded_archs, **kwargs):
Expand Down
52 changes: 28 additions & 24 deletions python/taichi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import shutil
import time
import random
import argparse
from taichi.tools.video import make_video, interpolate_frames, mp4_to_gif, scale_video, crop_video, accelerate_video


Expand Down Expand Up @@ -30,7 +31,7 @@ def test_python(test_files=(), verbose=False):
# run all the tests
args = [test_dir]
if verbose:
args += ['-s']
args += ['-s', '-v']
if len(test_files) == 0 or len(test_files) > 4:
if int(pytest.main([os.path.join(root_dir, 'misc/empty_pytest.py'), '-n1'])) == 0: # if pytest has xdist
try:
Expand All @@ -57,7 +58,19 @@ def test_cpp(test_files=()):
return int(task.run(*test_files))


def make_argument_parser():
parser = argparse.ArgumentParser()
parser.add_argument('action', help='See `ti help` for more details')
parser.add_argument('-v', '--verbose', action='store_true', help='Run with verbose outputs')
parser.add_argument('-a', '--arch', help='Specify arch(s) to run test on, e.g. -a opengl,metal')
parser.add_argument('files', nargs='*', help='Files to be tested')
return parser


def main(debug=False):
parser = make_argument_parser()
args = parser.parse_args()

lines = []
print()
lines.append(u' *******************************************')
Expand All @@ -73,16 +86,15 @@ def main(debug=False):
print(u'\n'.join(lines))
print()
import taichi as ti
if args.arch is not None:
ti.set_wanted_archs(args.arch.split(','))

argc = len(sys.argv)
if argc == 1 or sys.argv[1] == 'help':
if argc == 1 or args.action == 'help':
yuanming-hu marked this conversation as resolved.
Show resolved Hide resolved
print(
" Usage: ti run [task name] |-> Run a specific task\n"
" ti benchmark |-> Run performance benchmark\n"
" ti test |-> Run all tests\n"
" ti test_verbose |-> Run all tests with verbose outputs\n"
" ti test_python |-> Run python tests\n"
" ti test_cpp |-> Run cpp tests\n"
" ti test |-> Run all the tests\n"
" ti format |-> Reformat modified source files\n"
" ti format_all |-> Reformat all source files\n"
" ti build |-> Build C++ files\n"
Expand All @@ -94,8 +106,8 @@ def main(debug=False):
" ti doc |-> Build documentation\n"
" ti release |-> Make source code release\n"
" ti debug [script.py] |-> Debug script\n")
exit(0)
mode = sys.argv[1]
return 0
mode = args.action

t = time.time()
if mode.endswith('.py'):
Expand All @@ -104,32 +116,24 @@ def main(debug=False):
elif mode == "run":
if argc <= 2:
print("Please specify [task name], e.g. test_math")
exit(-1)
return -1
name = sys.argv[2]
task = ti.Task(name)
task.run(*sys.argv[3:])
elif mode == "debug":
ti.core.set_core_trigger_gdb_when_crash(True)
if argc <= 2:
print("Please specify [file name], e.g. render.py")
exit(-1)
return -1
name = sys.argv[2]
with open(name) as script:
script = script.read()
exec(script, {'__name__': '__main__'})
elif mode == "test_python":
return test_python(test_files=sys.argv[2:])
elif mode == "test_cpp":
return test_cpp(test_files=sys.argv[2:])
elif mode == "test":
if test_python(test_files=sys.argv[2:]) != 0:
return -1
if len(sys.argv) <= 2:
return test_cpp()
elif mode == "test_verbose":
if test_python(test_files=sys.argv[2:], verbose=True) != 0:
return -1
return test_cpp()
ret = test_python(test_files=args.files, verbose=args.verbose)
if ret: return -1
ret = test_cpp(test_files=args.files)
return ret
elif mode == "build":
ti.core.build()
elif mode == "format":
Expand Down Expand Up @@ -179,7 +183,7 @@ def main(debug=False):
elif mode == "video_crop":
if len(sys.argv) != 7:
print('Usage: ti video_crop fn x_begin x_end y_begin y_end')
exit(-1)
return -1
input_fn = sys.argv[2]
assert input_fn[-4:] == '.mp4'
output_fn = input_fn[:-4] + '-cropped.mp4'
Expand All @@ -191,7 +195,7 @@ def main(debug=False):
elif mode == "video_speed":
if len(sys.argv) != 4:
print('Usage: ti video_speed fn speed_up_factor')
exit(-1)
return -1
input_fn = sys.argv[2]
assert input_fn[-4:] == '.mp4'
output_fn = input_fn[:-4] + '-sped.mp4'
Expand Down