Skip to content

Commit

Permalink
fix: Vizier - Fixed pyvizier client study creation errors
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 544186919
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Jun 28, 2023
1 parent 69aaf01 commit 16299d1
Show file tree
Hide file tree
Showing 2 changed files with 451 additions and 38 deletions.
66 changes: 39 additions & 27 deletions google/cloud/aiplatform/vizier/pyvizier/proto_converters.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Converters for OSS Vizier's protos from/to PyVizier's classes."""
import datetime
import logging
from datetime import timezone
from typing import List, Optional, Sequence, Tuple, Union

from google.protobuf import duration_pb2
from google.protobuf import struct_pb2
from google.protobuf import timestamp_pb2
from google.cloud.aiplatform.compat.types import study as study_pb2
from google.cloud.aiplatform.vizier.pyvizier import ScaleType
from google.cloud.aiplatform.vizier.pyvizier import ParameterType
Expand Down Expand Up @@ -80,8 +82,8 @@ def _set_default_value(
default_value: Union[float, int, str],
):
"""Sets the protos' default_value field."""
which_pv_spec = proto.WhichOneof("parameter_value_spec")
getattr(proto, which_pv_spec).default_value.value = default_value
which_pv_spec = proto._pb.WhichOneof("parameter_value_spec")
getattr(proto, which_pv_spec).default_value = default_value

@classmethod
def _matching_parent_values(
Expand Down Expand Up @@ -280,17 +282,16 @@ def to_proto(
cls, parameter_value: ParameterValue, name: str
) -> study_pb2.Trial.Parameter:
"""Returns Parameter Proto."""
proto = study_pb2.Trial.Parameter(parameter_id=name)

if isinstance(parameter_value.value, int):
proto.value.number_value = parameter_value.value
value = struct_pb2.Value(number_value=parameter_value.value)
elif isinstance(parameter_value.value, bool):
proto.value.bool_value = parameter_value.value
value = struct_pb2.Value(bool_value=parameter_value.value)
elif isinstance(parameter_value.value, float):
proto.value.number_value = parameter_value.value
value = struct_pb2.Value(number_value=parameter_value.value)
elif isinstance(parameter_value.value, str):
proto.value.string_value = parameter_value.value
value = struct_pb2.Value(string_value=parameter_value.value)

proto = study_pb2.Trial.Parameter(parameter_id=name, value=value)
return proto


Expand Down Expand Up @@ -340,18 +341,19 @@ def from_proto(cls, proto: study_pb2.Measurement) -> Measurement:
@classmethod
def to_proto(cls, measurement: Measurement) -> study_pb2.Measurement:
"""Converts to Measurement proto."""
proto = study_pb2.Measurement()
int_seconds = int(measurement.elapsed_secs)
proto = study_pb2.Measurement(
elapsed_duration=duration_pb2.Duration(
seconds=int_seconds,
nanos=int(1e9 * (measurement.elapsed_secs - int_seconds)),
)
)
for name, metric in measurement.metrics.items():
proto.metrics.append(
study_pb2.Measurement.Metric(metric_id=name, value=metric.value)
)

proto.step_count = measurement.steps
int_seconds = int(measurement.elapsed_secs)
proto.elapsed_duration = duration_pb2.Duration(
seconds=int_seconds,
nanos=int(1e9 * (measurement.elapsed_secs - int_seconds)),
)
return proto


Expand Down Expand Up @@ -426,8 +428,11 @@ def from_proto(cls, proto: study_pb2.Trial) -> Trial:
infeasibility_reason = None
if proto.state == study_pb2.Trial.State.SUCCEEDED:
if proto.end_time:
completion_ts = proto.end_time.nanosecond / 1e9
completion_time = datetime.datetime.fromtimestamp(completion_ts)
completion_time = (
proto.end_time.timestamp_pb()
.ToDatetime()
.replace(tzinfo=timezone.utc)
)
elif proto.state == study_pb2.Trial.State.INFEASIBLE:
infeasibility_reason = proto.infeasible_reason

Expand All @@ -437,8 +442,11 @@ def from_proto(cls, proto: study_pb2.Trial) -> Trial:

creation_time = None
if proto.start_time:
creation_ts = proto.start_time.nanosecond / 1e9
creation_time = datetime.datetime.fromtimestamp(creation_ts)
creation_time = (
proto.start_time.timestamp_pb()
.ToDatetime()
.replace(tzinfo=timezone.utc)
)
return Trial(
id=int(proto.name.split("/")[-1]),
description=proto.name,
Expand Down Expand Up @@ -481,22 +489,26 @@ def to_proto(cls, pytrial: Trial) -> study_pb2.Trial:

# pytrial always adds an empty metric. Ideally, we should remove it if the
# metric does not exist in the study config.
# setattr() is required here as `proto.final_measurement.CopyFrom`
# raises AttributeErrors when setting the field on the pb2 compat types.
if pytrial.final_measurement is not None:
proto.final_measurement.CopyFrom(
MeasurementConverter.to_proto(pytrial.final_measurement)
setattr(
proto,
"final_measurement",
MeasurementConverter.to_proto(pytrial.final_measurement),
)

for measurement in pytrial.measurements:
proto.measurements.append(MeasurementConverter.to_proto(measurement))

if pytrial.creation_time is not None:
creation_secs = datetime.datetime.timestamp(pytrial.creation_time)
proto.start_time.seconds = int(creation_secs)
proto.start_time.nanos = int(1e9 * (creation_secs - int(creation_secs)))
start_time = timestamp_pb2.Timestamp()
start_time.FromDatetime(pytrial.creation_time)
setattr(proto, "start_time", start_time)
if pytrial.completion_time is not None:
completion_secs = datetime.datetime.timestamp(pytrial.completion_time)
proto.end_time.seconds = int(completion_secs)
proto.end_time.nanos = int(1e9 * (completion_secs - int(completion_secs)))
end_time = timestamp_pb2.Timestamp()
end_time.FromDatetime(pytrial.completion_time)
setattr(proto, "end_time", end_time)
if pytrial.infeasibility_reason is not None:
proto.infeasible_reason = pytrial.infeasibility_reason
return proto
Loading

0 comments on commit 16299d1

Please sign in to comment.