diff --git a/api/api.py b/api/api.py
index 7138909..404bd94 100644
--- a/api/api.py
+++ b/api/api.py
@@ -1,21 +1,22 @@
+import sqlite3
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+
+import numpy as np
 import torch
+from download import download_file, split_and_download
+from file_links import database as database_link
+from file_links import links, tourn_links
 from torch.utils.data import Dataset
-from pathlib import Path
+from utils import find_closest_indices
+
 from sc2_replay_reader import (
     GAME_INFO_FILE,
     ReplayDataAllDatabase,
     ReplayDataAllParser,
     Result,
 )
-import sqlite3
-from typing import List
-import numpy as np
-from dataclasses import dataclass
-from yeet import download_file, split_and_download
-from file_links import links, tourn_links, database as database_link
-from utils import find_closest_indices
-
-from enum import Enum
 
 
 @dataclass
@@ -49,7 +50,7 @@ def __init__(
         features: set[str] | None = None,
         database: Path = Path("sc2_dataset.db"),
         timepoints: TimeRange = TimeRange(0, 30, 0.5),
-        sql_filters: List[str] | None = None,
+        sql_filters: list[str] | None = None,
     ):
         """
         Args:
@@ -110,7 +111,7 @@ def __init__(
         _loop_per_min = 22.4 * 60
         self._target_game_loops = (timepoints.arange() * _loop_per_min).to(torch.int)
 
-    def load_database(self, database: Path, sql_filters: List[str] | None = None):
+    def load_database(self, database: Path, sql_filters: list[str] | None = None):
         self.database = database
         self.sql_filters = sql_filters
         sql_filter_string = (
@@ -184,7 +185,7 @@ def getitem(self, file_name: Path, db_index: int):
 
         outputs = {
             "win": torch.as_tensor(
-                self.parser.data.playerResult == Result.Win, dtype=torch.float32
+                self.parser.info.playerResult == Result.Win, dtype=torch.float32
             ),
             "valid": torch.cat([torch.tensor([True]), sample_indices != -1]),
         }
diff --git a/api/yeet.py b/api/download.py
similarity index 82%
rename from api/yeet.py
rename to api/download.py
index 01da2b8..2c3f922 100644
--- a/api/yeet.py
+++ b/api/download.py
@@ -1,26 +1,27 @@
-import requests
-from urllib.parse import unquote
+import random
 from pathlib import Path
+from urllib.parse import unquote
+
+import requests
 from tqdm import tqdm
-import random
-from typing import List, Tuple
 
 
-def get_filename_from_url(url):
-    response = requests.head(url)
+def get_filename_from_url(url: str):
+    """Extract the filename to be downloaded from a url"""
+    response = requests.head(url, timeout=60)
     content_disposition = response.headers.get("Content-Disposition")
     if content_disposition and "filename=" in content_disposition:
         return unquote(content_disposition.split("filename=")[1].strip('"'))
-    else:
-        return url.split("/")[-1]
+    return url.split("/")[-1]
 
 
-def download_file(url, destination=Path(".")):
+def download_file(url: str, destination: Path = Path().cwd()):
+    """Download file from url to destination"""
     file_name = get_filename_from_url(url)
     destination /= file_name
     print(f"downloading {destination}")
-    response = requests.get(url, stream=True)
+    response = requests.get(url, stream=True, timeout=60)
     total_size = int(response.headers.get("content-length", 0))
     block_size = 1024  # 1 Kibibyte
@@ -34,10 +35,9 @@ def download_file(
 def split_and_download(
-    links_and_sizes: List[Tuple[str, float]], percentage_split: float, folder: Path
+    links_and_sizes: list[tuple[str, float]], percentage_split: float, folder: Path
 ):
-    random.shuffle(links_and_sizes)
-
+    """Download files from public repository based on percentage split to destination folder"""
     total_size = sum(size for _, size in links_and_sizes)
     target_size = total_size * percentage_split
diff --git a/api/utils.py b/api/utils.py
index 78c6ee9..6ad6d6a 100644
--- a/api/utils.py
+++ b/api/utils.py
@@ -1,11 +1,11 @@
 from typing import Sequence
+
 import torch
 
 
 def find_closest_indices(options: Sequence[int], targets: Sequence[int]):
     """
     Find the closest option corresponding to a target, if there is no match, place -1
-    TODO Convert this to cpp
     """
     tgt_idx = 0
     nearest = torch.full([len(targets)], -1, dtype=torch.int32)