diff --git a/bentoml/adapters/base_input.py b/bentoml/adapters/base_input.py index 53f4b9e7d0b..60663a338ca 100644 --- a/bentoml/adapters/base_input.py +++ b/bentoml/adapters/base_input.py @@ -43,6 +43,7 @@ class BaseInputAdapter: def __init__(self, http_input_example=None, **base_config): self._config = base_config self._http_input_example = http_input_example + self.custom_request_schema = base_config.get('request_schema') @property def config(self): diff --git a/bentoml/service/inference_api.py b/bentoml/service/inference_api.py index b7152322c8a..d88dba6f86a 100644 --- a/bentoml/service/inference_api.py +++ b/bentoml/service/inference_api.py @@ -203,7 +203,11 @@ def request_schema(self): """ :return: the HTTP API request schema in OpenAPI/Swagger format """ - schema = self.input_adapter.request_schema + if self.input_adapter.custom_request_schema is None: + schema = self.input_adapter.request_schema + else: + schema = self.input_adapter.custom_request_schema + if schema.get('application/json'): schema.get('application/json')[ 'example' diff --git a/tests/integration/projects/general/service.py b/tests/integration/projects/general/service.py index e79c7d2ffd0..d7565b4bf94 100644 --- a/tests/integration/projects/general/service.py +++ b/tests/integration/projects/general/service.py @@ -70,6 +70,23 @@ def predict_json(self, input_datas): def customezed_route(self, input_datas): return input_datas + CUSTOM_SCHEMA = { + "application/json": { + "schema": { + "type": "object", + "required": ["field1", "field2"], + "properties": { + "field1": {"type": "string"}, + "field2": {"type": "uuid"}, + }, + }, + } + } + + @bentoml.api(input=JsonInput(request_schema=CUSTOM_SCHEMA), batch=True) + def customezed_schema(self, input_datas): + return input_datas + @bentoml.api(input=JsonInput(), batch=True) def predict_strict_json(self, input_datas, tasks: Sequence[InferenceTask] = None): filtered_jsons = [] diff --git a/tests/integration/projects/general/tests/test_meta.py b/tests/integration/projects/general/tests/test_meta.py index bc601f1bce3..9f599bb4cb6 100644 --- a/tests/integration/projects/general/tests/test_meta.py +++ b/tests/integration/projects/general/tests/test_meta.py @@ -34,3 +34,17 @@ def path_in_docs(response_body): data=json.dumps("hello"), assert_data=bytes('"hello"', 'ascii'), ) + + +@pytest.mark.asyncio +async def test_customized_request_schema(host): + def has_customized_schema(doc_bytes): + json_str = doc_bytes.decode() + return "field1" in json_str + + await pytest.assert_request( + "GET", + f"http://{host}/docs.json", + headers=(("Content-Type", "application/json"),), + assert_data=has_customized_schema, + ) diff --git a/tests/integration/projects/general/tests/test_microbatch.py b/tests/integration/projects/general/tests/test_microbatch.py index 7910b26e406..3a1a5e9ea76 100644 --- a/tests/integration/projects/general/tests/test_microbatch.py +++ b/tests/integration/projects/general/tests/test_microbatch.py @@ -33,6 +33,15 @@ async def test_slow_server(host): await asyncio.gather(*tasks) assert time.time() - time_start < 12 + +@pytest.mark.asyncio +async def test_fast_server(host): + if not pytest.enable_microbatch: + pytest.skip() + + A, B = 0.0002, 0.01 + data = '{"a": %s, "b": %s}' % (A, B) + req_count = 100 tasks = tuple( pytest.assert_request( @@ -46,17 +55,8 @@ async def test_slow_server(host): ) await asyncio.gather(*tasks) - -@pytest.mark.asyncio -async def test_fast_server(host): - if not pytest.enable_microbatch: - pytest.skip() - - A, B = 0.0002, 0.01 - data = '{"a": %s, "b": %s}' % (A, B) - time_start = time.time() - req_count = 500 + req_count = 200 tasks = tuple( pytest.assert_request( "POST", @@ -70,4 +70,4 @@ async def test_fast_server(host): for i in range(req_count) ) await asyncio.gather(*tasks) - assert time.time() - time_start < 5 + assert time.time() - time_start < 2 diff --git a/tests/integration/utils.py b/tests/integration/utils.py index 2fd91589fab..c6e076e960d 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -16,20 +16,21 @@ def _wait_until_api_server_ready(host_url, timeout, container=None, check_interval=1): start_time = time.time() + proxy_handler = urllib.request.ProxyHandler({}) + opener = urllib.request.build_opener(proxy_handler) + ex = None while time.time() - start_time < timeout: try: - if ( - urllib.request.urlopen(f'http://{host_url}/healthz', timeout=1).status - == 200 - ): - break + if opener.open(f'http://{host_url}/healthz', timeout=1).status == 200: + return elif container.status != "running": break else: logger.info("Waiting for host %s to be ready..", host_url) time.sleep(check_interval) except Exception as e: # pylint:disable=broad-except - logger.info(f"'{e}', retrying to connect to the host {host_url}...") + logger.info(f"retrying to connect to the host {host_url}...") + ex = e time.sleep(check_interval) finally: if container: @@ -40,7 +41,8 @@ def _wait_until_api_server_ready(host_url, timeout, container=None, check_interv logger.info(f">>> {log_record}") else: raise AssertionError( - f"Timed out waiting {timeout} seconds for Server {host_url} to be ready" + f"Timed out waiting {timeout} seconds for Server {host_url} to be ready, " + f"exception: {ex}" ) @@ -148,6 +150,8 @@ def print_log(p): ) as p: host_url = f"127.0.0.1:{port}" threading.Thread(target=print_log, args=(p,), daemon=True).start() - _wait_until_api_server_ready(host_url, timeout=timeout) - yield host_url - p.terminate() + try: + _wait_until_api_server_ready(host_url, timeout=timeout) + yield host_url + finally: + p.terminate()