From 5cb90fa194d55c86fae640178731dfd3c6aea617 Mon Sep 17 00:00:00 2001 From: Askaholic Date: Sat, 18 Sep 2021 15:14:56 -0800 Subject: [PATCH] Fix misc errors that show up in the logs (#830) * Clean up formatting * Log decode errors in JsonStats as WARNING instead of ERROR * Ensure that the server will start, even if hydra is unreachable * Log on INFO level when game setup times out * Handle exceptions when processing game stats in a task * Correctly check error text during deadlock handling * Check gamestate before launching * Reduce pop timer log level to info --- server/db/__init__.py | 3 +- server/gameconnection.py | 24 ++++-- server/games/game.py | 19 ++--- server/ladder_service.py | 13 ++- server/matchmaker/pop_timer.py | 2 +- server/oauth_service.py | 56 +++++++++++-- server/stats/game_stats_service.py | 24 +++++- tests/integration_tests/test_oauth_session.py | 1 - tests/unit_tests/test_database.py | 41 +++++++++ tests/unit_tests/test_gameconnection.py | 84 +++++++++++++++---- tests/unit_tests/test_oauth_service.py | 35 ++++++++ tests/unit_tests/test_oauth_session.py | 2 - 12 files changed, 254 insertions(+), 50 deletions(-) create mode 100644 tests/unit_tests/test_database.py create mode 100644 tests/unit_tests/test_oauth_service.py diff --git a/server/db/__init__.py b/server/db/__init__.py index 16c78ab73..6d49e53ee 100644 --- a/server/db/__init__.py +++ b/server/db/__init__.py @@ -57,7 +57,8 @@ async def deadlock_retry_execute(conn, *args, max_attempts=3): try: return await conn.execute(*args) except OperationalError as e: - if any(msg in e.message for msg in ( + error_text = str(e) + if any(msg in error_text for msg in ( "Deadlock found", "Lock wait timeout exceeded" )): diff --git a/server/gameconnection.py b/server/gameconnection.py index 66af0db86..7f03214e0 100644 --- a/server/gameconnection.py +++ b/server/gameconnection.py @@ -4,6 +4,7 @@ import asyncio import contextlib +import json from typing import Any, List from sqlalchemy import select, text @@ -362,7 +363,14 @@ async def handle_operation_complete( self.game.leaderboard_saved = True async def handle_json_stats(self, stats: str): - self.game.report_army_stats(stats) + try: + self.game.report_army_stats(stats) + except json.JSONDecodeError: + self._logger.warning( + "Malformed game stats reported by %s: '...%s'", + self._player.login, + stats[-20:] + ) async def handle_enforce_rating(self): self.game.enforce_rating = True @@ -471,11 +479,15 @@ async def handle_game_state(self, state: str): if self.player.state != PlayerState.HOSTING: return - self._logger.info( - "Launching game %s in state %s", - self.game, - self.game.state - ) + if self.game.state is not GameState.LOBBY: + self._logger.warning( + "Trying to launch game %s in invalid state %s", + self.game, + self.game.state + ) + return + + self._logger.info("Launching game %s", self.game) await self.game.launch() diff --git a/server/games/game.py b/server/games/game.py index c917c9cfc..70412ffe6 100644 --- a/server/games/game.py +++ b/server/games/game.py @@ -117,9 +117,8 @@ def __init__( "RestrictedCategories": 0 } self.mods = {} - loop = asyncio.get_event_loop() - self._is_hosted = loop.create_future() - self._launch_fut = loop.create_future() + self._hosted_event = asyncio.Event() + self._launched_event = asyncio.Event() self._logger.debug("%s created", self) asyncio.get_event_loop().create_task(self.timeout_game(setup_timeout)) @@ -127,10 +126,7 @@ def __init__( async def timeout_game(self, timeout: int = 60): await asyncio.sleep(timeout) if self.state is GameState.INITIALIZING: - self._is_hosted.set_exception( - asyncio.TimeoutError("Game setup timed out") - ) - self._logger.debug("Game setup timed out.. Cancelling game") + self._logger.debug("Game setup timed out, cancelling game") await self.on_game_end() @property @@ -272,17 +268,16 @@ def get_team_sets(self) -> List[Set[Player]]: async def wait_hosted(self, timeout: float): return await asyncio.wait_for( - asyncio.shield(self._is_hosted), + self._hosted_event.wait(), timeout=timeout ) def set_hosted(self): - if not self._is_hosted.done(): - self._is_hosted.set_result(None) + self._hosted_event.set() async def wait_launched(self, timeout: float): return await asyncio.wait_for( - asyncio.shield(self._launch_fut), + self._launched_event.wait(), timeout=timeout ) @@ -707,7 +702,7 @@ async def launch(self): await self.on_game_launched() await self.validate_game_settings() - self._launch_fut.set_result(None) + self._launched_event.set() self._logger.info("Game launched") async def on_game_launched(self): diff --git a/server/ladder_service.py b/server/ladder_service.py index a89714e3b..4049cf5f8 100644 --- a/server/ladder_service.py +++ b/server/ladder_service.py @@ -466,11 +466,18 @@ def game_options(player: Player) -> GameLaunchOptions: if guest.lobby_connection is not None ]) await game.wait_launched(60 + 10 * len(all_guests)) - self._logger.debug("Ladder game launched successfully") - except Exception: + self._logger.debug("Ladder game launched successfully %s", game) + except Exception as e: + if isinstance(e, asyncio.TimeoutError): + self._logger.info( + "Ladder game failed to start! %s setup timed out", + game + ) + else: + self._logger.exception("Ladder game failed to start %s", game) + if game: await game.on_game_end() - self._logger.exception("Failed to start ladder game!") msg = {"command": "match_cancelled"} for player in all_players: diff --git a/server/matchmaker/pop_timer.py b/server/matchmaker/pop_timer.py index a5d10e981..b4f83007c 100644 --- a/server/matchmaker/pop_timer.py +++ b/server/matchmaker/pop_timer.py @@ -74,7 +74,7 @@ def time_until_next_pop(self, num_queued: int, time_queued: float) -> float: # Obtained by solving $ NUM_PLAYERS = rate * time $ for time. next_pop_time = desired_players * total_times / total_players if next_pop_time > config.QUEUE_POP_TIME_MAX: - self._logger.warning( + self._logger.info( "Required time (%.2fs) for %s is larger than max pop time (%ds). " "Consider increasing the max pop time", next_pop_time, self.queue.name, config.QUEUE_POP_TIME_MAX diff --git a/server/oauth_service.py b/server/oauth_service.py index 6e7de1861..47bab9c3e 100644 --- a/server/oauth_service.py +++ b/server/oauth_service.py @@ -1,3 +1,5 @@ +import time + import aiocron import aiohttp import jwt @@ -6,6 +8,7 @@ from server.config import config +from .asyncio_extensions import synchronizedmethod from .core import Service from .decorators import with_logger from .exceptions import AuthenticationError @@ -19,6 +22,7 @@ class OAuthService(Service, name="oauth_service"): def __init__(self): self.public_keys = {} + self._last_key_fetch_time = None async def initialize(self) -> None: await self.retrieve_public_keys() @@ -28,25 +32,59 @@ async def initialize(self) -> None: "*/10 * * * *", func=self.retrieve_public_keys ) + @synchronizedmethod + async def get_public_keys(self) -> dict: + """ + Return cached keys, or fetch them if they're missing + """ + if not self.public_keys: + # Rate limit requests so we don't spam the endpoint when it's down + if ( + not self._last_key_fetch_time or + time.monotonic() - self._last_key_fetch_time > 5 + ): + await self.retrieve_public_keys() + + if not self.public_keys: + raise RuntimeError("jwks could not be retrieved") + + return self.public_keys + async def retrieve_public_keys(self) -> None: """ Get the latest jwks from the hydra endpoint """ - async with aiohttp.ClientSession(raise_for_status=True) as session: - async with session.get(config.HYDRA_JWKS_URI) as resp: - jwks = await resp.json() - self.public_keys = { - jwk["kid"]: RSAAlgorithm.from_jwk(jwk) - for jwk in jwks["keys"] - } + self._last_key_fetch_time = time.monotonic() + try: + async with aiohttp.ClientSession(raise_for_status=True) as session: + async with session.get(config.HYDRA_JWKS_URI) as resp: + jwks = await resp.json() + self.public_keys = { + jwk["kid"]: RSAAlgorithm.from_jwk(jwk) + for jwk in jwks["keys"] + } + except Exception: + self._logger.exception( + "Unable to retrieve jwks, token login will be unavailable!" + ) async def get_player_id_from_token(self, token: str) -> int: """ Decode the JWT to get the player_id """ + # Ensures that if we're missing the jwks we will try to fetch them on + # each new login request. This way our login functionality will be + # restored as soon as possible + keys = await self.get_public_keys() try: kid = jwt.get_unverified_header(token)["kid"] - key = self.public_keys[kid] - return int(jwt.decode(token, key=key, algorithms="RS256", options={"verify_aud": False})["sub"]) + key = keys[kid] + decoded = jwt.decode( + token, + key=key, + algorithms="RS256", + options={"verify_aud": False} + ) + return int(decoded["sub"]) except (InvalidTokenError, KeyError, ValueError): raise AuthenticationError("Token signature was invalid", "token") diff --git a/server/stats/game_stats_service.py b/server/stats/game_stats_service.py index 895b5517b..e84e51137 100644 --- a/server/stats/game_stats_service.py +++ b/server/stats/game_stats_service.py @@ -18,7 +18,29 @@ def __init__(self, event_service: EventService, achievement_service: Achievement self._event_service = event_service self._achievement_service = achievement_service - async def process_game_stats(self, player: Player, game: Game, army_stats_list: List): + async def process_game_stats( + self, + player: Player, + game: Game, + army_stats_list: List + ): + try: + await self._process_game_stats(player, game, army_stats_list) + except KeyError as e: + self._logger.info("Malformed game stats. KeyError: %s", e) + except Exception: + self._logger.exception( + "Error processing game stats for %s in game %d", + player.login, + game.id + ) + + async def _process_game_stats( + self, + player: Player, + game: Game, + army_stats_list: List + ): stats = None number_of_humans = 0 highest_score = 0 diff --git a/tests/integration_tests/test_oauth_session.py b/tests/integration_tests/test_oauth_session.py index 8cca968e8..671e70bf5 100644 --- a/tests/integration_tests/test_oauth_session.py +++ b/tests/integration_tests/test_oauth_session.py @@ -1,4 +1,3 @@ -import asyncio import os import pytest diff --git a/tests/unit_tests/test_database.py b/tests/unit_tests/test_database.py new file mode 100644 index 000000000..4511b8082 --- /dev/null +++ b/tests/unit_tests/test_database.py @@ -0,0 +1,41 @@ +from unittest import mock + +import asynctest +import pymysql +import pytest + +from server.db import deadlock_retry_execute +from tests.utils import fast_forward + + +@pytest.mark.asyncio +@fast_forward(10) +async def test_deadlock_retry_execute(): + mock_conn = mock.Mock() + mock_conn.execute = asynctest.CoroutineMock( + side_effect=pymysql.err.OperationalError(-1, "Deadlock found") + ) + + with pytest.raises(pymysql.err.OperationalError): + await deadlock_retry_execute(mock_conn, "foo") + + assert mock_conn.execute.call_count == 3 + + +@pytest.mark.asyncio +@fast_forward(10) +async def test_deadlock_retry_execute_success(): + call_count = 0 + + async def execute(*args): + nonlocal call_count + call_count += 1 + if call_count <= 1: + raise pymysql.err.OperationalError(-1, "Deadlock found") + + mock_conn = mock.Mock() + mock_conn.execute = asynctest.CoroutineMock(side_effect=execute) + + await deadlock_retry_execute(mock_conn, "foo") + + assert mock_conn.execute.call_count == 2 diff --git a/tests/unit_tests/test_gameconnection.py b/tests/unit_tests/test_gameconnection.py index 2fcc7a6c0..a076b0df7 100644 --- a/tests/unit_tests/test_gameconnection.py +++ b/tests/unit_tests/test_gameconnection.py @@ -27,8 +27,7 @@ @pytest.fixture def real_game(event_loop, database, game_service, game_stats_service): - game = Game(42, database, game_service, game_stats_service) - yield game + return Game(42, database, game_service, game_stats_service) def assert_message_sent(game_connection: GameConnection, command, args): @@ -243,12 +242,28 @@ async def test_handle_action_GameState_launching_calls_launch( game_connection.player = players.hosting game_connection.game = game game.launch = CoroutineMock() + game.state = GameState.LOBBY await game_connection.handle_action("GameState", ["Launching"]) game.launch.assert_any_call() +async def test_handle_action_GameState_launching_when_ended( + game: Game, + game_connection: GameConnection, + players +): + game_connection.player = players.hosting + game_connection.game = game + game.launch = CoroutineMock() + game.state = GameState.ENDED + + await game_connection.handle_action("GameState", ["Launching"]) + + game.launch.assert_not_called() + + async def test_handle_action_GameState_ended_calls_on_connection_lost( game_connection: GameConnection ): @@ -307,6 +322,7 @@ async def test_handle_action_GameMods_post_launch_updates_played_cache( database ): game.launch = CoroutineMock() + game.state = GameState.LOBBY game.remove_game_connection = CoroutineMock() await game_connection.handle_action("GameMods", ["uids", "foo bar EA040F8E-857A-4566-9879-0D37420A5B9D"]) @@ -318,7 +334,10 @@ async def test_handle_action_GameMods_post_launch_updates_played_cache( assert 2 == row[0] -async def test_handle_action_AIOption(game: Game, game_connection: GameConnection): +async def test_handle_action_AIOption( + game: Game, + game_connection: GameConnection +): await game_connection.handle_action("AIOption", ["QAI", "StartSpot", 1]) game.set_ai_option.assert_called_once_with("QAI", "StartSpot", 1) @@ -333,7 +352,10 @@ async def test_handle_action_AIOption_not_host( game.set_ai_option.assert_not_called() -async def test_handle_action_ClearSlot(game: Game, game_connection: GameConnection): +async def test_handle_action_ClearSlot( + game: Game, + game_connection: GameConnection +): await game_connection.handle_action("ClearSlot", [1]) game.clear_slot.assert_called_once_with(1) await game_connection.handle_action("ClearSlot", ["1"]) @@ -350,14 +372,21 @@ async def test_handle_action_ClearSlot_not_host( game.clear_slot.assert_not_called() -async def test_handle_action_GameResult_calls_add_result(game: Game, game_connection: GameConnection): +async def test_handle_action_GameResult_calls_add_result( + game: Game, + game_connection: GameConnection +): game_connection.connect_to_host = CoroutineMock() await game_connection.handle_action("GameResult", [0, "score -5"]) game.add_result.assert_called_once_with(game_connection.player.id, 0, "score", -5, frozenset()) -async def test_cannot_parse_game_results(game: Game, game_connection: GameConnection, caplog): +async def test_cannot_parse_game_results( + game: Game, + game_connection: GameConnection, + caplog +): game_connection.connect_to_host = CoroutineMock() with caplog.at_level(logging.WARNING): @@ -366,7 +395,10 @@ async def test_cannot_parse_game_results(game: Game, game_connection: GameConnec assert 'Invalid result' in caplog.records[0].message -async def test_handle_action_GameOption(game: Game, game_connection: GameConnection): +async def test_handle_action_GameOption( + game: Game, + game_connection: GameConnection +): game.gameOptions = {"AIReplacement": "Off"} await game_connection.handle_action("GameOption", ["Victory", "sandbox"]) assert game.gameOptions["Victory"] == Victory.SANDBOX @@ -394,18 +426,35 @@ async def test_handle_action_GameOption_not_host( assert game.gameOptions == {"Victory": "asdf"} -async def test_json_stats(game_connection: GameConnection, game_stats_service, players, game): - game_stats_service.process_game_stats = mock.Mock() +async def test_json_stats( + real_game: Game, + game_connection: GameConnection, +): + game_connection.game = real_game await game_connection.handle_action("JsonStats", ['{"stats": {}}']) - game.report_army_stats.assert_called_once_with('{"stats": {}}') -async def test_handle_action_EnforceRating(game: Game, game_connection: GameConnection): +async def test_json_stats_malformed( + real_game: Game, + game_connection: GameConnection, +): + game_connection.game = real_game + await game_connection.handle_action("JsonStats", ['{"stats": {}']) + + +async def test_handle_action_EnforceRating( + game: Game, + game_connection: GameConnection +): await game_connection.handle_action("EnforceRating", []) assert game.enforce_rating is True -async def test_handle_action_TeamkillReport(game: Game, game_connection: GameConnection, database): +async def test_handle_action_TeamkillReport( + game: Game, + game_connection: GameConnection, + database +): game.launch = CoroutineMock() await game_connection.handle_action("TeamkillReport", ["200", "2", "Dostya", "3", "Rhiza"]) @@ -416,7 +465,10 @@ async def test_handle_action_TeamkillReport(game: Game, game_connection: GameCon assert report is None -async def test_handle_action_TeamkillHappened(game: Game, game_connection: GameConnection, database): +async def test_handle_action_TeamkillHappened( + game: Game, + game_connection: GameConnection, database +): game.launch = CoroutineMock() await game_connection.handle_action("TeamkillHappened", ["200", "2", "Dostya", "3", "Rhiza"]) @@ -427,7 +479,11 @@ async def test_handle_action_TeamkillHappened(game: Game, game_connection: GameC assert game.id == row[0] -async def test_handle_action_TeamkillHappened_AI(game: Game, game_connection: GameConnection, database): +async def test_handle_action_TeamkillHappened_AI( + game: Game, + game_connection: GameConnection, + database +): # Should fail with a sql constraint error if this isn't handled correctly game_connection.abort = CoroutineMock() await game_connection.handle_action("TeamkillHappened", ["200", 0, "Dostya", "0", "Rhiza"]) diff --git a/tests/unit_tests/test_oauth_service.py b/tests/unit_tests/test_oauth_service.py new file mode 100644 index 000000000..ec09fddfb --- /dev/null +++ b/tests/unit_tests/test_oauth_service.py @@ -0,0 +1,35 @@ +import pytest +from asynctest import CoroutineMock + +from server import OAuthService + + +@pytest.fixture +def oauth_service(): + return OAuthService() + + +@pytest.mark.asyncio +async def test_get_public_keys(oauth_service): + def set_public_keys(): + oauth_service.public_keys = {"any": "value"} + + oauth_service.retrieve_public_keys = CoroutineMock( + side_effect=set_public_keys + ) + + public_keys = await oauth_service.get_public_keys() + + assert public_keys == {"any": "value"} + oauth_service.retrieve_public_keys.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_public_keys_cached(oauth_service): + oauth_service.public_keys = {"any": "value"} + oauth_service.retrieve_public_keys = CoroutineMock() + + public_keys = await oauth_service.get_public_keys() + + assert public_keys == {"any": "value"} + oauth_service.retrieve_public_keys.assert_not_called() diff --git a/tests/unit_tests/test_oauth_session.py b/tests/unit_tests/test_oauth_session.py index 8d92093a2..a84cf3cf3 100644 --- a/tests/unit_tests/test_oauth_session.py +++ b/tests/unit_tests/test_oauth_session.py @@ -1,5 +1,3 @@ -import time - import pytest from oauthlib.oauth2.rfc6749.errors import ( InsecureTransportError,