Skip to content

Commit

Permalink
Use epath instead of os to read files
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 696497054
  • Loading branch information
tomvdw authored and The TensorFlow Datasets Authors committed Nov 14, 2024
1 parent c3c62b8 commit 40cc288
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions tensorflow_datasets/datasets/smart_buildings/controller_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Callable, Mapping, Sequence, TypeVar, Union

from absl import logging
from etils import epath
import pandas as pd
from tensorflow_datasets.datasets.smart_buildings import constants
from tensorflow_datasets.datasets.smart_buildings import reader_lib
Expand All @@ -45,8 +46,8 @@ class ProtoReader(reader_lib.BaseReader):
input_dir: directory path where the files are located
"""

def __init__(self, input_dir):
self._input_dir = input_dir
def __init__(self, input_dir: epath.PathLike):
self._input_dir = epath.Path(input_dir)
logging.info('Reader lib input directory %s', self._input_dir)

def read_observation_responses(
Expand Down Expand Up @@ -97,7 +98,7 @@ def read_reward_responses( # pytype: disable=signature-mismatch # overriding-r

def read_zone_infos(self) -> Sequence[smart_control_building_pb2.ZoneInfo]:
"""Reads the zone infos for the Building from .pbtxt."""
filename = os.path.join(self._input_dir, constants.ZONE_INFO_PREFIX)
filename = self._input_dir / constants.ZONE_INFO_PREFIX
return self._read_streamed_protos(
filename, smart_control_building_pb2.ZoneInfo.FromString
)
Expand All @@ -107,7 +108,7 @@ def read_device_infos(
) -> Sequence[smart_control_building_pb2.DeviceInfo]:
"""Reads the device infos for the Building."""

filename = os.path.join(self._input_dir, constants.DEVICE_INFO_PREFIX)
filename = self._input_dir / constants.DEVICE_INFO_PREFIX
return self._read_streamed_protos(
filename, smart_control_building_pb2.DeviceInfo.FromString
)
Expand Down Expand Up @@ -141,28 +142,26 @@ def _read_messages(
messages.extend(file_messages)
return messages

def _read_shards(self, input_dir: str, file_prefix: str) -> Sequence[str]:
def _read_shards(
self, input_dir: epath.Path, file_prefix: str
) -> Sequence[epath.Path]:
"""Returns full paths in input_dir of files starting with file_prefix."""

shards = [
os.path.join(input_dir, f)
for f in os.listdir(input_dir)
if f.startswith(file_prefix)
]
return shards
return list(epath.Path(input_dir).glob(f'{file_prefix}*'))

def _select_shards(
self,
start_time: pd.Timestamp,
end_time: pd.Timestamp,
shards: Sequence[str],
) -> Sequence[str]:
shards: Sequence[epath.Path],
) -> Sequence[epath.Path]:
"""Returns the shards that fall inside the start and end times."""

def _read_timestamp(filepath: str) -> pd.Timestamp:
def _read_timestamp(filepath: epath.Path) -> pd.Timestamp:
"""Reads the timestamp from the filepath."""
assert filepath
ts = pd.Timestamp(re.findall(r'\d{4}\.\d{2}\.\d{2}\.\d{2}', filepath)[-1])
ts = pd.Timestamp(
re.findall(r'\d{4}\.\d{2}\.\d{2}\.\d{2}', os.fspath(filepath))[-1]
)
return ts

def _between(
Expand All @@ -179,13 +178,13 @@ def _between(

def _read_streamed_protos(
self,
full_path: str,
full_path: epath.Path,
from_string_func: Callable[[Union[bytearray, bytes, memoryview]], T],
) -> Sequence[T]:
"""Reads a proto which has byte size preceding the message."""

messages = []
with open(full_path, 'rb') as f:
with full_path.open('rb') as f:
while True:
# Read size as a varint
size_bytes = f.read(4)
Expand Down Expand Up @@ -260,7 +259,7 @@ def get_episode_data(working_dir: str) -> pd.DataFrame:
Returns:
A dataframe with episode label, timestamps, number of updates.
"""
episode_dirs = os.listdir(working_dir)
episode_dirs = list(epath.Path(working_dir).iterdir())
date_extractor = operator.itemgetter(slice(-13, None))

execution_times = pd.to_datetime(
Expand Down

0 comments on commit 40cc288

Please sign in to comment.