diff --git a/sdks/go/pkg/beam/transforms/periodic/periodic.go b/sdks/go/pkg/beam/transforms/periodic/periodic.go index 5a7b4d0cf536..84ee50f2ae6c 100644 --- a/sdks/go/pkg/beam/transforms/periodic/periodic.go +++ b/sdks/go/pkg/beam/transforms/periodic/periodic.go @@ -61,6 +61,15 @@ func NewSequenceDefinition(start, end time.Time, interval time.Duration) Sequenc } } +func CalculateByteSizeOfSequence(now time.Time, sd SequenceDefinition, rest offsetrange.Restriction) int64 { + // Find the # of outputs expected for overlap of and [-inf, now) + nowIndex := int64(now.Sub(mtime.Time(sd.Start).ToTime()) / sd.Interval) + if nowIndex < rest.Start { + return 0 + } + return 8 * (min(rest.End, nowIndex) - rest.Start) +} + type sequenceGenDoFn struct{} func (fn *sequenceGenDoFn) CreateInitialRestriction(sd SequenceDefinition) offsetrange.Restriction { @@ -75,8 +84,8 @@ func (fn *sequenceGenDoFn) CreateTracker(rest offsetrange.Restriction) *sdf.Lock return sdf.NewLockRTracker(offsetrange.NewTracker(rest)) } -func (fn *sequenceGenDoFn) RestrictionSize(_ SequenceDefinition, rest offsetrange.Restriction) float64 { - return rest.Size() +func (fn *sequenceGenDoFn) RestrictionSize(sd SequenceDefinition, rest offsetrange.Restriction) float64 { + return float64(CalculateByteSizeOfSequence(time.Now(), sd, rest)) } func (fn *sequenceGenDoFn) SplitRestriction(_ SequenceDefinition, rest offsetrange.Restriction) []offsetrange.Restriction { diff --git a/sdks/go/pkg/beam/transforms/periodic/periodic_test.go b/sdks/go/pkg/beam/transforms/periodic/periodic_test.go index a34edf8d07b8..56dba8778684 100644 --- a/sdks/go/pkg/beam/transforms/periodic/periodic_test.go +++ b/sdks/go/pkg/beam/transforms/periodic/periodic_test.go @@ -21,6 +21,7 @@ import ( "time" "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/io/rtrackers/offsetrange" "github.com/apache/beam/sdks/v2/go/pkg/beam/options/jobopts" _ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism" "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" @@ -56,3 +57,38 @@ func TestImpulse(t *testing.T) { passert.Count(s, out, "SecondsInMinute", 60) ptest.RunAndValidate(t, p) } + +func TestSize(t *testing.T) { + sd := SequenceDefinition{ + Interval: 10 * time.Second, + Start: 0, + End: 1000 * time.Minute.Milliseconds(), + } + end := int64((1000 * time.Minute) / (10 * time.Second)) + + sizeTests := []struct { + now, startIndex, endIndex, want int64 + }{ + {100, 10, end, 0}, + {100, 9, end, 8}, + {100, 8, end, 16}, + {101, 9, end, 8}, + {10000, 0, end, 8 * 10000 / 10}, + {10000, 1002, 1003, 0}, + {10100, 1002, 1003, 8}, + } + + for _, test := range sizeTests { + got := CalculateByteSizeOfSequence( + time.Unix(test.now, 0), + sd, + offsetrange.Restriction{ + Start: int64(test.startIndex), + End: int64(test.endIndex), + }) + if got != test.want { + t.Errorf("TestBytes(%v, %v, %v) = %v, want %v", + test.now, test.startIndex, test.endIndex, got, test.want) + } + } +} diff --git a/sdks/python/apache_beam/transforms/periodicsequence.py b/sdks/python/apache_beam/transforms/periodicsequence.py index b2d7b375571b..61c9aacd920c 100644 --- a/sdks/python/apache_beam/transforms/periodicsequence.py +++ b/sdks/python/apache_beam/transforms/periodicsequence.py @@ -48,14 +48,27 @@ def initial_restriction(self, element): def create_tracker(self, restriction): return OffsetRestrictionTracker(restriction) - def restriction_size(self, unused_element, restriction): - return restriction.size() + def restriction_size(self, element, restriction): + return sequence_backlog_bytes(element, time.time(), restriction) # On drain, immediately stop emitting new elements def truncate(self, unused_element, unused_restriction): return None +def sequence_backlog_bytes(element, now, offset_range): + # Find the # of outputs expected for overlap of and [-inf, now) + start, _, interval = element + if isinstance(start, Timestamp): + start = start.micros / 1000000 + assert interval > 0 + + now_index = math.floor((now - start) / interval) + if now_index < offset_range.start: + return 0 + return 8 * (min(offset_range.stop, now_index) - offset_range.start) + + class ImpulseSeqGenDoFn(beam.DoFn): ''' ImpulseSeqGenDoFn fn receives tuple elements with three parts: diff --git a/sdks/python/apache_beam/transforms/periodicsequence_test.py b/sdks/python/apache_beam/transforms/periodicsequence_test.py index 932779555341..58b3c2a61b87 100644 --- a/sdks/python/apache_beam/transforms/periodicsequence_test.py +++ b/sdks/python/apache_beam/transforms/periodicsequence_test.py @@ -24,11 +24,13 @@ import unittest import apache_beam as beam +from apache_beam.io.restriction_trackers import OffsetRange from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.transforms.periodicsequence import PeriodicImpulse from apache_beam.transforms.periodicsequence import PeriodicSequence +from apache_beam.transforms.periodicsequence import sequence_backlog_bytes # Disable frequent lint warning due to pipe operator for chaining transforms. # pylint: disable=expression-not-assigned @@ -112,6 +114,24 @@ def test_periodicsequence_outputs_valid_sequence_in_past(self): self.assertEqual(result.is_bounded, False) assert_that(result, equal_to(k)) + def test_periodicsequence_output_size(self): + element = [0, 1000000000, 10] + self.assertEqual( + sequence_backlog_bytes(element, 100, OffsetRange(10, 100000000)), 0) + self.assertEqual( + sequence_backlog_bytes(element, 100, OffsetRange(9, 100000000)), 8) + self.assertEqual( + sequence_backlog_bytes(element, 100, OffsetRange(8, 100000000)), 16) + self.assertEqual( + sequence_backlog_bytes(element, 101, OffsetRange(9, 100000000)), 8) + self.assertEqual( + sequence_backlog_bytes(element, 10000, OffsetRange(0, 100000000)), + 8 * 10000 / 10) + self.assertEqual( + sequence_backlog_bytes(element, 10000, OffsetRange(1002, 1003)), 0) + self.assertEqual( + sequence_backlog_bytes(element, 10100, OffsetRange(1002, 1003)), 8) + if __name__ == '__main__': unittest.main()