Skip to content

Commit

Permalink
Update parser.py
Browse files Browse the repository at this point in the history
  • Loading branch information
R-N committed Dec 14, 2023
1 parent 4dc10d7 commit 4547279
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions ml_draftpick_dss/parsing/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .classifier import MatchResultClassifier, HeroIconClassifier, MedalClassifier, ScreenshotClassifier, HERO_ICON_IMG_SIZE, MEDAL_LABELS
from .grouping import infer_ss_type, read_opening_failure, check_opening_failure
from .util import inference_save_path, read_save_path, save_inference, mkdir, exception_message
import time

BAD_FILE_EXCEPTIONS = [
"HISTORY",
Expand Down Expand Up @@ -81,7 +82,9 @@ def __init__(
hero_icon_classifier,
medal_classifier,
ocr=None, img_size=None,
inference_save_dir="inferences"
inference_save_dir="inferences",
forgive_afk=False,
forgive_invalid=False,
):
self.input_dir = input_dir
assert isinstance(ss_classifier, ScreenshotClassifier)
Expand All @@ -100,6 +103,16 @@ def __init__(
self.scaler = None
self.ocr = ocr or OCR(has_number=False)

self.total_inference_time = 0
self.total_parse_time = 0
self.forgive_afk = forgive_afk
self.forgive_invalid = forgive_invalid
self.n_size = 0

@property
def total_time(self):
return self.total_inference_time + self.total_parse_time

def input_dir_player(self, player_name):
return os.path.join(self.input_dir, player_name)

Expand Down Expand Up @@ -154,15 +167,19 @@ def infer(self, ss_path, player_name, throw=False, return_img=False):
opening_failure = self.check_opening_failure(opening_failure_text)
assert ((not throw) or (not opening_failure)), f"OPENING_FAILURE: {ss_path}"

t0 = time.time()

match_result, match_result_img = self.infer_match_result(img, bgr=False)
assert ((not throw) or (match_result != "Invalid")), f"INVALID: {ss_path}"
assert (self.forgive_invalid or (not throw) or (match_result != "Invalid")), f"INVALID: {ss_path}"

medals, medals_img = self.infer_medals(img, bgr=False)
assert ((not throw) or ("AFK" not in (medals[0] + medals[1]))), f"AFK: {ss_path}; {medals}"
assert (self.forgive_afk or (not throw) or ("AFK" not in (medals[0] + medals[1]))), f"AFK: {ss_path}; {medals}"

heroes, heroes_img = self.infer_heroes(img, bgr=False)
assert ((not throw) or (len(set(heroes[0] + heroes[1])) == 10)), f"DOUBLE: {ss_path}; {heroes}"

t1 = time.time()

try:
battle_id, battle_id_img = self.read_battle_id(img, bgr=False, throw=throw)
match_duration, match_duration_img = self.read_match_duration(img, bgr=False, throw=throw)
Expand All @@ -173,6 +190,12 @@ def infer(self, ss_path, player_name, throw=False, return_img=False):
err_type, err_detail = message.split(":", maxsplit=1)
new_message = f"{err_type}: {relpath}; {err_detail.strip()}"
raise AssertionError(new_message)

t2 = time.time()

self.n_size += 1
self.total_inference_time += t1 - t0
self.total_parse_time += t2 - t1

assert ((not throw) or (0 == len([1 for i in range(2) for s in scores[i] if s >= 17.0]))), f"OVERSCORE: {ss_path}; {scores}"

Expand Down

0 comments on commit 4547279

Please sign in to comment.