diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 55f42d57..c5c5b800 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -12,15 +12,16 @@ # License for the specific language governing permissions and limitations # under the License. -from pytest import fixture, mark -from supertokens_python import InputAppInfo, SupertokensConfig, init -from supertokens_python.framework.fastapi import get_middleware -from supertokens_python.recipe import emailpassword, session, passwordless import json from fastapi import FastAPI +from pytest import fixture, mark + +from supertokens_python import InputAppInfo, SupertokensConfig, init +from supertokens_python.framework.fastapi import get_middleware +from supertokens_python.recipe import emailpassword, passwordless, session from tests.testclient import TestClientWithNoCookieJar as TestClient -from tests.utils import clean_st, reset, setup_st, start_st, sign_up_request +from tests.utils import clean_st, reset, setup_st, sign_up_request, start_st def setup_function(_): @@ -61,7 +62,9 @@ async def test_rid_with_session_and_non_existent_api_in_session_recipe_still_hit start_st() response = driver_config_client.post(url="/auth/signin", headers={"rid": "session"}) - assert response.status_code == 400 + assert response.status_code == 200 + dict_response = json.loads(response.text) + assert dict_response["status"] == "FIELD_ERROR" @mark.asyncio @@ -84,7 +87,9 @@ async def test_no_rid_with_existent_API_does_not_give_404( start_st() response = driver_config_client.post(url="/auth/signin") - assert response.status_code == 400 + assert response.status_code == 200 + dict_response = json.loads(response.text) + assert dict_response["status"] == "FIELD_ERROR" @mark.asyncio @@ -109,7 +114,9 @@ async def test_rid_as_anticsrf_with_existent_API_does_not_give_404( response = driver_config_client.post( url="/auth/signin", headers={"rid": "anti-csrf"} ) - assert response.status_code == 400 + assert response.status_code == 200 + dict_response = json.loads(response.text) + assert dict_response["status"] == "FIELD_ERROR" @mark.asyncio @@ -132,7 +139,9 @@ async def test_random_rid_with_existent_API_does_hits_api( start_st() response = driver_config_client.post(url="/auth/signin", headers={"rid": "random"}) - assert response.status_code == 400 + assert response.status_code == 200 + dict_response = json.loads(response.text) + assert dict_response["status"] == "FIELD_ERROR" @mark.asyncio