Skip to content

Commit

Permalink
Code format
Browse files Browse the repository at this point in the history
  • Loading branch information
PGZXB committed Mar 17, 2022
1 parent 1f1a8ab commit 4007837
Showing 1 changed file with 31 additions and 22 deletions.
53 changes: 31 additions & 22 deletions tests/python/test_offline_cache.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import functools
from genericpath import exists
import math
from os import listdir, path, remove, rmdir

import pytest
from genericpath import exists

import taichi as ti
from tests import test_utils
Expand Down Expand Up @@ -131,18 +131,22 @@ def _test_offline_cache_for_a_kernel(curr_arch, kernel, args, result):
def _test_closing_offline_cache_for_a_kernel(curr_arch, kernel, args, result):
count_of_cache_file = len(listdir(tmp_offline_cache_file_path))

ti.init(arch=curr_arch, enable_fallback=False, offline_cache_file_path=tmp_offline_cache_file_path)
ti.init(arch=curr_arch,
enable_fallback=False,
offline_cache_file_path=tmp_offline_cache_file_path)
res1 = kernel(*args)
assert len(listdir(tmp_offline_cache_file_path)
) - count_of_cache_file == 0 * cache_files_num_per_kernel

ti.init(arch=curr_arch, enable_fallback=False, offline_cache_file_path=tmp_offline_cache_file_path)

ti.init(arch=curr_arch,
enable_fallback=False,
offline_cache_file_path=tmp_offline_cache_file_path)
assert len(listdir(tmp_offline_cache_file_path)
) - count_of_cache_file == 0 * cache_files_num_per_kernel
res2 = kernel(*args)
assert res1 == test_utils.approx(result) and res1 == test_utils.approx(
res2)

ti.reset()
assert len(listdir(tmp_offline_cache_file_path)
) - count_of_cache_file == 0 * cache_files_num_per_kernel
Expand Down Expand Up @@ -190,11 +194,10 @@ def compute_y():
assert y[None] == 12.0
assert x.grad[None] == 12.0


ti.init(arch=curr_arch, enable_fallback=False, **ext_init_options)
helper()
assert len(listdir(tmp_offline_cache_file_path)
) - count_of_cache_file == 0 * cache_files_num_per_kernel
) - count_of_cache_file == 0 * cache_files_num_per_kernel

ti.init(arch=curr_arch, enable_fallback=False, **ext_init_options)
assert len(listdir(tmp_offline_cache_file_path)
Expand Down Expand Up @@ -264,17 +267,19 @@ def helper():
assert x[None] == test_utils.approx(6.28)
assert y[None] == test_utils.approx(7.28)


assert len(listdir(tmp_offline_cache_file_path)) - count_of_cache_file == 0 * cache_files_num_per_kernel
assert len(listdir(tmp_offline_cache_file_path)
) - count_of_cache_file == 0 * cache_files_num_per_kernel
ti.init(arch=curr_arch, enable_fallback=False, **ext_init_options)
helper()

ti.init(arch=curr_arch, enable_fallback=False, **ext_init_options)
assert len(listdir(tmp_offline_cache_file_path)) - count_of_cache_file == 4 * cache_files_num_per_kernel
assert len(listdir(tmp_offline_cache_file_path)
) - count_of_cache_file == 4 * cache_files_num_per_kernel
helper()

ti.reset()
assert len(listdir(tmp_offline_cache_file_path)) - count_of_cache_file == 4 * cache_files_num_per_kernel
assert len(listdir(tmp_offline_cache_file_path)
) - count_of_cache_file == 4 * cache_files_num_per_kernel


@pytest.mark.parametrize('curr_arch', supported_archs_offline_cache)
Expand All @@ -293,17 +298,20 @@ def helper():
assert a[2][1] == 1
assert a[3][4] == 4
assert a[4][9] == 9

assert len(listdir(tmp_offline_cache_file_path)) - count_of_cache_file == 0 * cache_files_num_per_kernel

assert len(listdir(tmp_offline_cache_file_path)
) - count_of_cache_file == 0 * cache_files_num_per_kernel
ti.init(arch=curr_arch, enable_fallback=False, **ext_init_options)
helper()

ti.init(arch=curr_arch, enable_fallback=False, **ext_init_options)
assert len(listdir(tmp_offline_cache_file_path)) - count_of_cache_file == 2 * cache_files_num_per_kernel
assert len(listdir(tmp_offline_cache_file_path)
) - count_of_cache_file == 2 * cache_files_num_per_kernel
helper()

ti.reset()
assert len(listdir(tmp_offline_cache_file_path)) - count_of_cache_file == 2 * cache_files_num_per_kernel
assert len(listdir(tmp_offline_cache_file_path)
) - count_of_cache_file == 2 * cache_files_num_per_kernel


@pytest.mark.parametrize('curr_arch', supported_archs_offline_cache)
Expand All @@ -313,18 +321,19 @@ def test_calling_many_kernels(curr_arch):

def helper():
for kernel, args, get_res in simple_kernels_to_test:
assert(kernel(*args) == test_utils.approx(get_res(*args)))

assert (kernel(*args) == test_utils.approx(get_res(*args)))

ti.init(arch=curr_arch, enable_fallback=False, **ext_init_options)
helper()
assert len(listdir(tmp_offline_cache_file_path)
) - count_of_cache_file == 0 * cache_files_num_per_kernel
) - count_of_cache_file == 0 * cache_files_num_per_kernel

ti.init(arch=curr_arch, enable_fallback=False, **ext_init_options)
assert len(listdir(tmp_offline_cache_file_path)
) - count_of_cache_file == len(simple_kernels_to_test) * cache_files_num_per_kernel
assert len(
listdir(tmp_offline_cache_file_path)) - count_of_cache_file == len(
simple_kernels_to_test) * cache_files_num_per_kernel
helper()
ti.reset()
assert len(listdir(tmp_offline_cache_file_path)
) - count_of_cache_file == len(simple_kernels_to_test) * cache_files_num_per_kernel
assert len(
listdir(tmp_offline_cache_file_path)) - count_of_cache_file == len(
simple_kernels_to_test) * cache_files_num_per_kernel

0 comments on commit 4007837

Please sign in to comment.