diff --git a/qa/L0_python_api/test_kserve.py b/qa/L0_python_api/test_kserve.py index b810a38125..f5c0ec25bb 100644 --- a/qa/L0_python_api/test_kserve.py +++ b/qa/L0_python_api/test_kserve.py @@ -452,10 +452,12 @@ def test_restrict_inference(self, frontend, client_type, url, key_prefix): server = utils.setup_server() # Specifying restricted features that restricts inference. + infer_key, infer_value = "infer-key", "infer-value" + rf = RestrictedFeatures() rf.create_feature_group( - key="infer-key", - value="infer-value", + key=infer_key, + value=infer_value, features=[Feature.INFERENCE], ) @@ -466,13 +468,21 @@ def test_restrict_inference(self, frontend, client_type, url, key_prefix): headers = {key_prefix + "infer-key": "infer-value"} assert utils.send_and_test_inference_identity(client_type, url, headers) - # Invalid headers sent with inference request - headers = {key_prefix + "fake-key": "fake-value"} - with pytest.raises( - InferenceServerException, - match=f"expecting header '{key_prefix}infer-key'", - ): - utils.send_and_test_inference_identity(client_type, url, headers) + # Combinations of Invalid (or no) headers sent with inference request + invalid_key_value = {key_prefix + "fake-key": "fake-value"} + invalid_value = {key_prefix + infer_key: "fake-value"} + error_msg = f"expecting header '{key_prefix}infer-key'" + + for header, err_msg in [ + (invalid_key_value, error_msg), + (invalid_value, error_msg), + (None, error_msg), + ]: + with pytest.raises( + InferenceServerException, + match=err_msg, + ): + utils.send_and_test_inference_identity(client_type, url, header) utils.teardown_service(service) utils.teardown_server(server)