Skip to content

Commit

Permalink
Use sequential keys to make keys distribute better
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 703450511
  • Loading branch information
tomvdw authored and The TensorFlow Datasets Authors committed Dec 6, 2024
1 parent d2ad852 commit 824cfbb
Showing 1 changed file with 13 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
# The years in the dataset.
YEARS = [19, 20, 21, 22, 23, 24]

_MAX_EXAMPLES_PER_DAY = 24 * 12 # 288 per day

_REWARD_RESPONSES = [
'agentRewardValue',
'productivityReward',
Expand Down Expand Up @@ -196,18 +198,20 @@ def _generate_examples(self, path: epath.Path, year: int, pipeline):

return (
pipeline
| f'CreateDates_{year}' >> beam.Create(all_dates)
| f'CreateDates_{year}' >> beam.Create(enumerate(all_dates))
| f'ProcessDate_{year}' >> beam.FlatMap(process_date, path=path)
| f'Reshuffle_{year}' >> beam.Reshuffle()
)


def process_date(
start_time: pd.Timestamp,
day_index_and_start_time: tuple[int, pd.Timestamp],
path: epath.Path,
) -> Iterable[tuple[int, dict[str, Any]]]:
"""Process a single date."""
day_index, start_time = day_index_and_start_time
end_time = start_time + pd.Timedelta(hours=23)
key_offset = day_index * _MAX_EXAMPLES_PER_DAY

reader = controller_reader.ProtoReader(path)
observation_responses = reader.read_observation_responses(
Expand All @@ -230,6 +234,12 @@ def process_date(
key=lambda o: to_ns_timestamp(o.start_timestamp),
)

if len(observation_responses) > _MAX_EXAMPLES_PER_DAY:
raise ValueError(
f'Too many observation responses for date {start_time}: '
f'{len(observation_responses)} > {_MAX_EXAMPLES_PER_DAY}'
)

for i in range(len(observation_responses)):
observation_response = json_format.MessageToDict(observation_responses[i])
action_response = json_format.MessageToDict(action_responses[i])
Expand All @@ -256,7 +266,7 @@ def process_date(
reward_response[val] = -1 # sentinel value

beam.metrics.Metrics.counter(f'date_{start_time}', 'example_count').inc()
key = int(f'{start_time.toordinal()}{i:05d}')
key = key_offset + i
yield key, {
'observation': observation_response,
'action': action_response,
Expand Down

0 comments on commit 824cfbb

Please sign in to comment.