diff --git a/tensorflow_datasets/datasets/smart_buildings/controller_reader.py b/tensorflow_datasets/datasets/smart_buildings/controller_reader.py index fba69d9b53b..119da9b6c38 100644 --- a/tensorflow_datasets/datasets/smart_buildings/controller_reader.py +++ b/tensorflow_datasets/datasets/smart_buildings/controller_reader.py @@ -22,6 +22,7 @@ from typing import Callable, Mapping, Sequence, TypeVar, Union from absl import logging +from etils import epath import pandas as pd from tensorflow_datasets.datasets.smart_buildings import constants from tensorflow_datasets.datasets.smart_buildings import reader_lib @@ -45,8 +46,8 @@ class ProtoReader(reader_lib.BaseReader): input_dir: directory path where the files are located """ - def __init__(self, input_dir): - self._input_dir = input_dir + def __init__(self, input_dir: epath.PathLike): + self._input_dir = epath.Path(input_dir) logging.info('Reader lib input directory %s', self._input_dir) def read_observation_responses( @@ -97,7 +98,7 @@ def read_reward_responses( # pytype: disable=signature-mismatch # overriding-r def read_zone_infos(self) -> Sequence[smart_control_building_pb2.ZoneInfo]: """Reads the zone infos for the Building from .pbtxt.""" - filename = os.path.join(self._input_dir, constants.ZONE_INFO_PREFIX) + filename = self._input_dir / constants.ZONE_INFO_PREFIX return self._read_streamed_protos( filename, smart_control_building_pb2.ZoneInfo.FromString ) @@ -107,7 +108,7 @@ def read_device_infos( ) -> Sequence[smart_control_building_pb2.DeviceInfo]: """Reads the device infos for the Building.""" - filename = os.path.join(self._input_dir, constants.DEVICE_INFO_PREFIX) + filename = self._input_dir / constants.DEVICE_INFO_PREFIX return self._read_streamed_protos( filename, smart_control_building_pb2.DeviceInfo.FromString ) @@ -141,28 +142,26 @@ def _read_messages( messages.extend(file_messages) return messages - def _read_shards(self, input_dir: str, file_prefix: str) -> Sequence[str]: + def _read_shards( + self, input_dir: epath.Path, file_prefix: str + ) -> Sequence[epath.Path]: """Returns full paths in input_dir of files starting with file_prefix.""" - - shards = [ - os.path.join(input_dir, f) - for f in os.listdir(input_dir) - if f.startswith(file_prefix) - ] - return shards + return list(epath.Path(input_dir).glob(f'{file_prefix}*')) def _select_shards( self, start_time: pd.Timestamp, end_time: pd.Timestamp, - shards: Sequence[str], - ) -> Sequence[str]: + shards: Sequence[epath.Path], + ) -> Sequence[epath.Path]: """Returns the shards that fall inside the start and end times.""" - def _read_timestamp(filepath: str) -> pd.Timestamp: + def _read_timestamp(filepath: epath.Path) -> pd.Timestamp: """Reads the timestamp from the filepath.""" assert filepath - ts = pd.Timestamp(re.findall(r'\d{4}\.\d{2}\.\d{2}\.\d{2}', filepath)[-1]) + ts = pd.Timestamp( + re.findall(r'\d{4}\.\d{2}\.\d{2}\.\d{2}', os.fspath(filepath))[-1] + ) return ts def _between( @@ -179,13 +178,13 @@ def _between( def _read_streamed_protos( self, - full_path: str, + full_path: epath.Path, from_string_func: Callable[[Union[bytearray, bytes, memoryview]], T], ) -> Sequence[T]: """Reads a proto which has byte size preceding the message.""" messages = [] - with open(full_path, 'rb') as f: + with full_path.open('rb') as f: while True: # Read size as a varint size_bytes = f.read(4) @@ -260,7 +259,7 @@ def get_episode_data(working_dir: str) -> pd.DataFrame: Returns: A dataframe with episode label, timestamps, number of updates. """ - episode_dirs = os.listdir(working_dir) + episode_dirs = list(epath.Path(working_dir).iterdir()) date_extractor = operator.itemgetter(slice(-13, None)) execution_times = pd.to_datetime(