Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
bojiang committed Feb 25, 2021
1 parent 2643868 commit 7b5d0fc
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 42 deletions.
7 changes: 1 addition & 6 deletions bentoml/adapters/base_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ class BaseInputAdapter:
def __init__(self, http_input_example=None, **base_config):
self._config = base_config
self._http_input_example = http_input_example
if base_config.get('request_schema') is not None:
self.request_schema = base_config['request_schema']
self.custom_request_schema = base_config.get('request_schema')

@property
def config(self):
Expand All @@ -58,10 +57,6 @@ def request_schema(self):
"""
return {"application/json": {"schema": {"type": "object"}}}

@request_schema.setter
def request_schema(self, schema):
self.__dict__['request_schema'] = schema

@property
def pip_dependencies(self):
"""
Expand Down
6 changes: 5 additions & 1 deletion bentoml/service/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,11 @@ def request_schema(self):
"""
:return: the HTTP API request schema in OpenAPI/Swagger format
"""
schema = self.input_adapter.request_schema
if self.input_adapter.custom_request_schema is None:
schema = self.input_adapter.request_schema
else:
schema = self.input_adapter.custom_request_schema

if schema.get('application/json'):
schema.get('application/json')[
'example'
Expand Down
22 changes: 11 additions & 11 deletions tests/integration/projects/general/tests/test_microbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ async def test_slow_server(host):
await asyncio.gather(*tasks)
assert time.time() - time_start < 12


@pytest.mark.asyncio
async def test_fast_server(host):
if not pytest.enable_microbatch:
pytest.skip()

A, B = 0.0002, 0.01
data = '{"a": %s, "b": %s}' % (A, B)

req_count = 100
tasks = tuple(
pytest.assert_request(
Expand All @@ -46,17 +55,8 @@ async def test_slow_server(host):
)
await asyncio.gather(*tasks)


@pytest.mark.asyncio
async def test_fast_server(host):
if not pytest.enable_microbatch:
pytest.skip()

A, B = 0.0002, 0.01
data = '{"a": %s, "b": %s}' % (A, B)

time_start = time.time()
req_count = 500
req_count = 200
tasks = tuple(
pytest.assert_request(
"POST",
Expand All @@ -70,4 +70,4 @@ async def test_fast_server(host):
for i in range(req_count)
)
await asyncio.gather(*tasks)
assert time.time() - time_start < 5
assert time.time() - time_start < 2
14 changes: 0 additions & 14 deletions tests/integration/projects/general_non_batch/tests/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,3 @@ async def test_api_server_meta(host):
await pytest.assert_request("GET", f"http://{host}/")
await pytest.assert_request("GET", f"http://{host}/healthz")
await pytest.assert_request("GET", f"http://{host}/docs.json")


@pytest.mark.asyncio
async def test_customized_request_schema(host):
def has_customized_schema(doc_bytes):
json_str = doc_bytes.decode()
return "field1" in json_str

await pytest.assert_request(
"GET",
f"http://{host}/docs.json",
headers=(("Content-Type", "application/json"),),
assert_data=has_customized_schema,
)
24 changes: 14 additions & 10 deletions tests/integration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,21 @@

def _wait_until_api_server_ready(host_url, timeout, container=None, check_interval=1):
start_time = time.time()
proxy_handler = urllib.request.ProxyHandler({})
opener = urllib.request.build_opener(proxy_handler)
ex = None
while time.time() - start_time < timeout:
try:
if (
urllib.request.urlopen(f'http://{host_url}/healthz', timeout=1).status
== 200
):
break
if opener.open(f'http://{host_url}/healthz', timeout=1).status == 200:
return
elif container.status != "running":
break
else:
logger.info("Waiting for host %s to be ready..", host_url)
time.sleep(check_interval)
except Exception as e: # pylint:disable=broad-except
logger.info(f"'{e}', retrying to connect to the host {host_url}...")
logger.info(f"retrying to connect to the host {host_url}...")
ex = e
time.sleep(check_interval)
finally:
if container:
Expand All @@ -40,7 +41,8 @@ def _wait_until_api_server_ready(host_url, timeout, container=None, check_interv
logger.info(f">>> {log_record}")
else:
raise AssertionError(
f"Timed out waiting {timeout} seconds for Server {host_url} to be ready"
f"Timed out waiting {timeout} seconds for Server {host_url} to be ready, "
f"exception: {ex}"
)


Expand Down Expand Up @@ -148,6 +150,8 @@ def print_log(p):
) as p:
host_url = f"127.0.0.1:{port}"
threading.Thread(target=print_log, args=(p,), daemon=True).start()
_wait_until_api_server_ready(host_url, timeout=timeout)
yield host_url
p.terminate()
try:
_wait_until_api_server_ready(host_url, timeout=timeout)
yield host_url
finally:
p.terminate()

0 comments on commit 7b5d0fc

Please sign in to comment.