Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

データセットを作る. #1129

Merged
merged 10 commits into from
Aug 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ dist
.pytest_cache
.cache
.ipynb_checkpoints
workspace/suphnx-reward-shaping/recourses/*
.DS_Store
.vscode/
.python_versions
Binary file added workspace/.DS_Store
Binary file not shown.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions workspace/suphnx-reward-shaping/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import json
import os
import sys

from google.protobuf import json_format

sys.path.append("../../../")
import mjxproto

sys.path.append("../")
from utils import to_dataset

mjxprotp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources")


def test_to_dataset():
num_resources = len(os.listdir(mjxprotp_dir))
features, scores = to_dataset(mjxprotp_dir)
assert features.shape == (num_resources, 6)
assert scores.shape == (num_resources, 1)
Empty file.
Empty file.
66 changes: 66 additions & 0 deletions workspace/suphnx-reward-shaping/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import json
import os
import random
import sys
from typing import Dict, Iterator, List, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from google.protobuf import json_format

sys.path.append("../../")
import mjxproto


def to_dataset(mjxprotp_dir: str) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
jsonが入っているディレクトリを引数としてjax.numpyのデータセットを作る.
"""
features: List = []
scores: List = []
for _json in os.listdir(mjxprotp_dir):
_json = os.path.join(mjxprotp_dir, _json)
assert ".json" in _json
with open(_json, "r") as f:
lines = f.readlines()
_dicts = [json.loads(round) for round in lines]
states = [json_format.ParseDict(d, mjxproto.State()) for d in _dicts]
target: int = random.randint(0, 3)
features.append(to_feature(states, target))
scores.append(to_final_scores(states, target))
features_array: jnp.ndarray = jnp.array(features)
scores_array: jnp.ndarray = jnp.array(scores)
return features_array, scores_array


def _select_one_round(states: List[mjxproto.State]) -> mjxproto.State:
"""
データセットに本質的で無い相関が生まれることを防ぐために一半荘につき1ペアのみを使う.
"""
idx: int = random.randint(0, len(states) - 1)
return states[idx]


def _calc_curr_pos(init_pos: int, round: int) -> int:
return init_pos + round % 4


def to_feature(states: List[mjxproto.State], target) -> List[int]:
"""
特徴量 = [終了時の点数, 自風, 親, 局, 本場, 詰み棒]
"""
state = _select_one_round(states)
ten: int = state.round_terminal.final_score.tens[target]
honba: int = state.round_terminal.final_score.honba
tsumibo: int = state.round_terminal.final_score.riichi
round: int = state.round_terminal.final_score.round
wind: int = _calc_curr_pos(target, round)
oya: int = _calc_curr_pos(0, round)
return [ten, honba, tsumibo, round, wind, oya]


def to_final_scores(states: List[mjxproto.State], target) -> List[int]:
final_state = states[-1]
final_score = final_state.round_terminal.final_score.tens[target]
return [final_score]