Skip to content

Commit

Permalink
Nextgen Proto Pythonic API: Timestamp/Duration assignment, creation a…
Browse files Browse the repository at this point in the history
…nd calculation

Timestamp and Duration are now have more support with datetime and timedelta:
- Allows assign python datetime to protobuf DateTime field in addition to current FromDatetime/ToDatetime (Note: will throw exceptions for the differences in supported ranges)
- Allows assign python timedelta to protobuf Duration field in addition to current FromTimedelta/ToTimedelta
- Calculation between Timestamp, Duration, datetime and timedelta will also be supported.

example usage:

from datetime import datetime, timedelta
from event_pb2 import Event
e = Event(start_time=datetime(year=2112, month=2, day=3),
          duration=timedelta(hours=10))
duration = timedelta(hours=10))
end_time = e.start_time + timedelta(hours=4)
e.duration = end_time - e.start_time
PiperOrigin-RevId: 640639168
  • Loading branch information
anandolee authored and copybara-github committed Jun 5, 2024
1 parent a450c9c commit b690e72
Show file tree
Hide file tree
Showing 7 changed files with 445 additions and 31 deletions.
10 changes: 10 additions & 0 deletions python/google/protobuf/internal/descriptor_pool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from google.protobuf.internal import no_package_pb2
from google.protobuf.internal import testing_refleaks

from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
from google.protobuf import unittest_features_pb2
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_import_public_pb2
Expand Down Expand Up @@ -435,6 +437,8 @@ def testAddSerializedFile(self):
self.assertEqual(file2.name,
'google/protobuf/internal/factory_test2.proto')
self.testFindMessageTypeByName()
self.pool.AddSerializedFile(timestamp_pb2.DESCRIPTOR.serialized_pb)
self.pool.AddSerializedFile(duration_pb2.DESCRIPTOR.serialized_pb)
file_json = self.pool.AddSerializedFile(
more_messages_pb2.DESCRIPTOR.serialized_pb)
field = file_json.message_types_by_name['class'].fields_by_name['int_field']
Expand Down Expand Up @@ -542,12 +546,18 @@ def testComplexNesting(self):
# that uses a DescriptorDatabase.
# TODO: Fix python and cpp extension diff.
return
timestamp_desc = descriptor_pb2.FileDescriptorProto.FromString(
timestamp_pb2.DESCRIPTOR.serialized_pb)
duration_desc = descriptor_pb2.FileDescriptorProto.FromString(
duration_pb2.DESCRIPTOR.serialized_pb)
more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString(
more_messages_pb2.DESCRIPTOR.serialized_pb)
test1_desc = descriptor_pb2.FileDescriptorProto.FromString(
descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb)
test2_desc = descriptor_pb2.FileDescriptorProto.FromString(
descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb)
self.pool.Add(timestamp_desc)
self.pool.Add(duration_desc)
self.pool.Add(more_messages_desc)
self.pool.Add(test1_desc)
self.pool.Add(test2_desc)
Expand Down
8 changes: 8 additions & 0 deletions python/google/protobuf/internal/more_messages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ syntax = "proto2";

package google.protobuf.internal;

import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";

// A message where tag numbers are listed out of order, to allow us to test our
// canonicalization of serialized output, which should always be in tag order.
// We also mix in some extensions for extra fun.
Expand Down Expand Up @@ -348,3 +351,8 @@ message ConflictJsonName {
optional int32 value = 1 [json_name = "old_value"];
optional int32 new_value = 2 [json_name = "value"];
}

message WKTMessage {
optional Timestamp optional_timestamp = 1;
optional Duration optional_duration = 2;
}
43 changes: 35 additions & 8 deletions python/google/protobuf/internal/python_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

__author__ = '[email protected] (Will Robinson)'

import datetime
from io import BytesIO
import struct
import sys
Expand Down Expand Up @@ -536,13 +537,30 @@ def init(self, **kwargs):
self._fields[field] = copy
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
copy = field._default_constructor(self)
new_val = field_value
if isinstance(field_value, dict):
new_val = None
if isinstance(field_value, message_mod.Message):
new_val = field_value
elif isinstance(field_value, dict):
new_val = field.message_type._concrete_class(**field_value)
try:
copy.MergeFrom(new_val)
except TypeError:
_ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
elif field.message_type.full_name == 'google.protobuf.Timestamp':
copy.FromDatetime(field_value)
elif field.message_type.full_name == 'google.protobuf.Duration':
copy.FromTimedelta(field_value)
else:
raise TypeError(
'Message field {0}.{1} must be initialized with a '
'dict or instance of same class, got {2}.'.format(
message_descriptor.name,
field_name,
type(field_value).__name__,
)
)

if new_val:
try:
copy.MergeFrom(new_val)
except TypeError:
_ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
self._fields[field] = copy
else:
if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
Expand Down Expand Up @@ -753,8 +771,17 @@ def getter(self):
# We define a setter just so we can throw an exception with a more
# helpful error message.
def setter(self, new_value):
raise AttributeError('Assignment not allowed to composite field '
'"%s" in protocol message object.' % proto_field_name)
if field.message_type.full_name == 'google.protobuf.Timestamp':
getter(self)
self._fields[field].FromDatetime(new_value)
elif field.message_type.full_name == 'google.protobuf.Duration':
getter(self)
self._fields[field].FromTimedelta(new_value)
else:
raise AttributeError(
'Assignment not allowed to composite field '
'"%s" in protocol message object.' % proto_field_name
)

# Add a property to encapsulate the getter.
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
Expand Down
51 changes: 46 additions & 5 deletions python/google/protobuf/internal/well_known_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import collections.abc
import datetime
import warnings

from google.protobuf.internal import field_mask
from typing import Union

FieldMask = field_mask.FieldMask

Expand Down Expand Up @@ -271,12 +271,35 @@ def FromDatetime(self, dt):
# manipulated into a long value of seconds. During the conversion from
# struct_time to long, the source date in UTC, and so it follows that the
# correct transformation is calendar.timegm()
seconds = calendar.timegm(dt.utctimetuple())
nanos = dt.microsecond * _NANOS_PER_MICROSECOND
try:
seconds = calendar.timegm(dt.utctimetuple())
nanos = dt.microsecond * _NANOS_PER_MICROSECOND
except AttributeError as e:
raise AttributeError(
'Fail to convert to Timestamp. Expected a datetime like '
'object got {0} : {1}'.format(type(dt).__name__, e)
) from e
_CheckTimestampValid(seconds, nanos)
self.seconds = seconds
self.nanos = nanos

def __add__(self, value) -> datetime.datetime:
if isinstance(value, Duration):
return self.ToDatetime() + value.ToTimedelta()
return self.ToDatetime() + value

__radd__ = __add__

def __sub__(self, value) -> Union[datetime.datetime, datetime.timedelta]:
if isinstance(value, Timestamp):
return self.ToDatetime() - value.ToDatetime()
elif isinstance(value, Duration):
return self.ToDatetime() - value.ToTimedelta()
return self.ToDatetime() - value

def __rsub__(self, dt) -> datetime.timedelta:
return dt - self.ToDatetime()


def _CheckTimestampValid(seconds, nanos):
if seconds < _TIMESTAMP_SECONDS_MIN or seconds > _TIMESTAMP_SECONDS_MAX:
Expand Down Expand Up @@ -408,8 +431,16 @@ def ToTimedelta(self) -> datetime.timedelta:

def FromTimedelta(self, td):
"""Converts timedelta to Duration."""
self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY,
td.microseconds * _NANOS_PER_MICROSECOND)
try:
self._NormalizeDuration(
td.seconds + td.days * _SECONDS_PER_DAY,
td.microseconds * _NANOS_PER_MICROSECOND,
)
except AttributeError as e:
raise AttributeError(
'Fail to convert to Duration. Expected a timedelta like '
'object got {0}: {1}'.format(type(td).__name__, e)
) from e

def _NormalizeDuration(self, seconds, nanos):
"""Set Duration by seconds and nanos."""
Expand All @@ -420,6 +451,16 @@ def _NormalizeDuration(self, seconds, nanos):
self.seconds = seconds
self.nanos = nanos

def __add__(self, value) -> Union[datetime.datetime, datetime.timedelta]:
if isinstance(value, Timestamp):
return self.ToTimedelta() + value.ToDatetime()
return self.ToTimedelta() + value

__radd__ = __add__

def __rsub__(self, dt) -> Union[datetime.datetime, datetime.timedelta]:
return dt - self.ToTimedelta()


def _CheckDurationValid(seconds, nanos):
if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX:
Expand Down
Loading

0 comments on commit b690e72

Please sign in to comment.