Skip to content

Commit

Permalink
Merge pull request #706 from toshihikoyanase/support-numpy-scalar-in-…
Browse files Browse the repository at this point in the history
…user-attrs

Support `numpy` scalars in `Trial.user_attrs`
  • Loading branch information
c-bata authored Nov 27, 2023
2 parents cacae81 + ad4d29f commit 1759389
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 4 deletions.
4 changes: 2 additions & 2 deletions optuna_dashboard/_cached_extra_study_property.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
import numbers
import threading
from typing import List
from typing import Optional
Expand Down Expand Up @@ -85,9 +86,8 @@ def update(self, trials: list[FrozenTrial]) -> None:
self._cursor = next_cursor

def _update_user_attrs(self, trial: FrozenTrial) -> None:
# TODO(c-bata): Support numpy-specific number types.
current_user_attrs = {
k: not isinstance(v, bool) and isinstance(v, (int, float))
k: not isinstance(v, bool) and isinstance(v, numbers.Real)
for k, v in trial.user_attrs.items()
}
for attr_name, current_is_sortable in current_user_attrs.items():
Expand Down
3 changes: 3 additions & 0 deletions optuna_dashboard/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from datetime import datetime
import json
import numbers
from typing import Any
from typing import TYPE_CHECKING
from typing import Union
Expand Down Expand Up @@ -104,6 +105,8 @@ def serialize_attrs(attrs: dict[str, Any]) -> list[Attribute]:
value = "<binary object>"
elif isinstance(v, str):
value = v
elif isinstance(v, numbers.Real):
value = str(v)
else:
value = json.dumps(v)
value = value[:MAX_ATTR_LENGTH] if len(value) > MAX_ATTR_LENGTH else value
Expand Down
23 changes: 21 additions & 2 deletions python_tests/test_cached_extra_study_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest import TestCase
import warnings

import numpy as np
import optuna
from optuna import create_trial
from optuna.distributions import BaseDistribution
Expand Down Expand Up @@ -254,11 +255,29 @@ def test_contains_failed_trials(self) -> None:

def test_infer_sortable(self) -> None:
user_attrs_list: list[dict[str, Any]] = [
{"a": 1, "b": 1, "c": 1, "d": "a", "e": 1, "f": True},
{
"a": 1,
"b": 1,
"c": 1,
"d": "a",
"e": 1,
"f": True,
"g": np.float128(1.1),
"h": np.int64(2),
},
{"a": 2, "b": "a", "c": "a", "d": "a"},
{"a": 3, "b": None, "c": 3, "d": "a", "e": 3},
]
expected = {"a": True, "b": False, "c": False, "d": False, "e": True, "f": False}
expected = {
"a": True,
"b": False,
"c": False,
"d": False,
"e": True,
"f": False,
"g": True,
"h": True,
}

trials = []
for user_attrs in user_attrs_list:
Expand Down
27 changes: 27 additions & 0 deletions python_tests/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import sys

import numpy as np
import optuna
from optuna_dashboard._serializer import serialize_attrs
from optuna_dashboard._serializer import serialize_study_detail
Expand All @@ -25,6 +26,32 @@ def test_serialize_dict() -> None:
assert len(serialized) <= 1


def test_serialize_numpy_integer() -> None:
serialized = serialize_attrs(
{
"int8": np.int8(1),
"int16": np.int16(1),
"int32": np.int32(1),
"int64": np.int64(1),
}
)
assert len(serialized) == 4
assert all([v["value"] == "1" for v in serialized])


def test_serialize_numpy_floating() -> None:
serialized = serialize_attrs(
{
"float16": np.float16(1.0),
"float32": np.float32(1.0),
"float64": np.float64(1.0),
"float128": np.float128(1.0),
}
)
assert len(serialized) == 4
assert all([v["value"] == "1.0" for v in serialized])


@pytest.mark.skipif(sys.version_info < (3, 8), reason="BoTorch dropped Python3.7 support")
def test_get_study_detail_is_preferential() -> None:
storage = optuna.storages.InMemoryStorage()
Expand Down

0 comments on commit 1759389

Please sign in to comment.