diff --git a/pdr_backend/lake/csv_data_store.py b/pdr_backend/lake/csv_data_store.py index f7991e245..e5094127c 100644 --- a/pdr_backend/lake/csv_data_store.py +++ b/pdr_backend/lake/csv_data_store.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # import os -from typing import Optional, Union +from typing import Dict, Optional, Union import polars as pl from enforce_typing import enforce_types @@ -168,15 +168,34 @@ def _get_folder_path(self) -> str: return folder_path - def get_file_paths(self, do_sort=True) -> list: + def get_file_paths(self, sort_by="alpha", filter_by: Optional[Dict] = None) -> list: """ Returns the file paths for the given table name (key). """ folder_path = self._get_folder_path() file_names = os.listdir(folder_path) - if do_sort: + if sort_by == "alpha": file_names = sorted(file_names) + elif sort_by == "none": + pass + + if filter_by: + filtered_file_names = [] + + for file_name in file_names: + from_val = _get_from_value(file_name) + to_val = _get_to_value(file_name) + + if to_val is not None and to_val < filter_by["from"]: + continue + + if from_val > filter_by["to"]: + continue + + filtered_file_names.append(file_name) + + file_names = filtered_file_names return [os.path.join(folder_path, file_name) for file_name in file_names] @@ -187,7 +206,7 @@ def has_data(self) -> bool: @returns: bool - True if the csv files have data """ - file_paths = self.get_file_paths(do_sort=False) + file_paths = self.get_file_paths(sort_by="none") # check if the csv file has more than 0 bytes return any(os.path.getsize(file_path) > 0 for file_path in file_paths) @@ -252,14 +271,16 @@ def _append_remaining_rows( return data.slice(remaining_rows, len(data) - remaining_rows) - def read_all(self, schema: Optional[SchemaDict] = None) -> pl.DataFrame: + def read_all( + self, schema: Optional[SchemaDict] = None, filters: Optional[Dict] = None + ) -> pl.DataFrame: """ Reads all the data from the csv files in the folder corresponding to the given table name (key). @returns: pl.DataFrame - data read from the csv files """ - file_paths = self.get_file_paths() + file_paths = self.get_file_paths(filter_by=filters) if not file_paths: return pl.DataFrame([], schema=schema) @@ -289,7 +310,9 @@ def read( @returns: pl.DataFrame - data read from the csv file """ - data = self.read_all(schema=schema) + data = self.read_all( + schema=schema, filters={"from": start_time, "to": end_time} + ) # if the data is empty, return if len(data) == 0: return data