Skip to content

Commit

Permalink
modify example api, but needs to be re-written
Browse files Browse the repository at this point in the history
Signed-off-by: Bryce Ferenczi <[email protected]>
  • Loading branch information
5had3z committed May 29, 2024
1 parent 3d6df43 commit 4cf63c3
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 27 deletions.
27 changes: 14 additions & 13 deletions api/api.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import sqlite3
from dataclasses import dataclass
from enum import Enum
from pathlib import Path

import numpy as np
import torch
from download import download_file, split_and_download
from file_links import database as database_link
from file_links import links, tourn_links
from torch.utils.data import Dataset
from pathlib import Path
from utils import find_closest_indices

from sc2_replay_reader import (
GAME_INFO_FILE,
ReplayDataAllDatabase,
ReplayDataAllParser,
Result,
)
import sqlite3
from typing import List
import numpy as np
from dataclasses import dataclass
from yeet import download_file, split_and_download
from file_links import links, tourn_links, database as database_link
from utils import find_closest_indices

from enum import Enum


@dataclass
Expand Down Expand Up @@ -49,7 +50,7 @@ def __init__(
features: set[str] | None = None,
database: Path = Path("sc2_dataset.db"),
timepoints: TimeRange = TimeRange(0, 30, 0.5),
sql_filters: List[str] | None = None,
sql_filters: list[str] | None = None,
):
"""
Args:
Expand Down Expand Up @@ -110,7 +111,7 @@ def __init__(
_loop_per_min = 22.4 * 60
self._target_game_loops = (timepoints.arange() * _loop_per_min).to(torch.int)

def load_database(self, database: Path, sql_filters: List[str] | None = None):
def load_database(self, database: Path, sql_filters: list[str] | None = None):
self.database = database
self.sql_filters = sql_filters
sql_filter_string = (
Expand Down Expand Up @@ -184,7 +185,7 @@ def getitem(self, file_name: Path, db_index: int):

outputs = {
"win": torch.as_tensor(
self.parser.data.playerResult == Result.Win, dtype=torch.float32
self.parser.info.playerResult == Result.Win, dtype=torch.float32
),
"valid": torch.cat([torch.tensor([True]), sample_indices != -1]),
}
Expand Down
26 changes: 13 additions & 13 deletions api/yeet.py → api/download.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import requests
from urllib.parse import unquote
import random
from pathlib import Path
from urllib.parse import unquote

import requests
from tqdm import tqdm
import random
from typing import List, Tuple


def get_filename_from_url(url):
response = requests.head(url)
def get_filename_from_url(url: str):
"""Extract the filename to be downloaded from a url"""
response = requests.head(url, timeout=60)
content_disposition = response.headers.get("Content-Disposition")

if content_disposition and "filename=" in content_disposition:
return unquote(content_disposition.split("filename=")[1].strip('"'))
else:
return url.split("/")[-1]
return url.split("/")[-1]


def download_file(url, destination=Path(".")):
def download_file(url: str, destination: Path = Path().cwd()):
"""Download file from url to destination"""
file_name = get_filename_from_url(url)
destination /= file_name
print(f"downloading {destination}")
response = requests.get(url, stream=True)
response = requests.get(url, stream=True, timeout=60)

total_size = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
Expand All @@ -34,10 +35,9 @@ def download_file(url, destination=Path(".")):


def split_and_download(
links_and_sizes: List[Tuple[str, float]], percentage_split: float, folder: Path
links_and_sizes: list[tuple[str, float]], percentage_split: float, folder: Path
):
random.shuffle(links_and_sizes)

"""Download files from public repository based on percentage split to destination folder"""
total_size = sum(size for _, size in links_and_sizes)
target_size = total_size * percentage_split

Expand Down
2 changes: 1 addition & 1 deletion api/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Sequence

import torch


def find_closest_indices(options: Sequence[int], targets: Sequence[int]):
"""
Find the closest option corresponding to a target, if there is no match, place -1
TODO Convert this to cpp
"""
tgt_idx = 0
nearest = torch.full([len(targets)], -1, dtype=torch.int32)
Expand Down

0 comments on commit 4cf63c3

Please sign in to comment.