Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(loaders): Add loader function for RPH and flexible file gathering #147

Merged
merged 5 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/gnatss/constants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,9 @@
GPS_COV_ZY,
GPS_COV_ZZ,
] # Covariance matrix columns

# Roll Pitch Heading (RPH) columns
# These names are used as the DataFrame column labels when RPH files are
# loaded by ``load_roll_pitch_heading`` in ``loaders.py``.
RPH_TIME = TIME_J2000  # RPH time column reuses the TIME_J2000 column name
RPH_ROLL = "roll"
RPH_PITCH = "pitch"
RPH_HEADING = "heading"
34 changes: 34 additions & 0 deletions src/gnatss/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,40 @@ def load_travel_times(
return all_travel_times


def load_roll_pitch_heading(files: List[str]) -> pd.DataFrame:
    """
    Loads roll pitch heading data into a pandas dataframe from a list of files.

    Each file is read as whitespace-delimited text without a header row.
    Within each file, rows that repeat an already-seen time value are
    dropped (the first occurrence is kept). Duplicates that span
    different files are not removed.

    Parameters
    ----------
    files : List[str]
        The list of path strings to the files to load

    Returns
    -------
    pd.DataFrame
        Pandas DataFrame containing all of
        the roll pitch heading data, with a fresh 0..n-1 index.
        Columns are named by ``constants.RPH_TIME``,
        ``constants.RPH_ROLL``, ``constants.RPH_PITCH`` and
        ``constants.RPH_HEADING``, in that order.
        If ``files`` is empty, an empty DataFrame with
        these columns is returned.
    """
    columns = [
        constants.RPH_TIME,
        constants.RPH_ROLL,
        constants.RPH_PITCH,
        constants.RPH_HEADING,
    ]

    # pd.concat raises ValueError on an empty sequence, so short-circuit
    # with an empty frame that still carries the expected columns.
    if not files:
        return pd.DataFrame(columns=columns)

    # Read all rph files.
    # ``sep=r"\s+"`` replaces ``delim_whitespace=True``, which is
    # deprecated as of pandas 2.2; both split on runs of whitespace.
    rph_dfs = [
        pd.read_csv(
            path, sep=r"\s+", header=None, names=columns
        ).drop_duplicates(constants.RPH_TIME)
        for path in files
    ]

    # ignore_index=True renumbers the combined rows, so resetting the
    # index of each individual frame beforehand is unnecessary.
    return pd.concat(rph_dfs, ignore_index=True)


def load_gps_solutions(
files: List[str], time_round: int = constants.DELAY_TIME_PRECISION
) -> pd.DataFrame:
Expand Down
13 changes: 10 additions & 3 deletions src/gnatss/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Literal, Tuple, Union

import numpy as np
import pandas as pd
Expand All @@ -24,7 +24,9 @@
from .utilities.io import _get_filesystem


def gather_files(config: Configuration) -> Dict[str, Any]:
def gather_files(
config: Configuration, proc: Literal["solver", "posfilter"] = "solver"
) -> Dict[str, Any]:
"""Gather file paths for the various dataset files

Parameters
Expand All @@ -38,7 +40,12 @@ def gather_files(config: Configuration) -> Dict[str, Any]:
A dictionary containing the various datasets file paths
"""
all_files_dict = {}
for k, v in config.solver.input_files.dict().items():
# Check for process type first
if not hasattr(config, proc):
raise AttributeError(f"Unknown process type: {proc}")

proc_config = getattr(config, proc)
for k, v in proc_config.input_files.dict().items():
path = v.get("path", "")
typer.echo(f"Gathering {k} at {path}")
storage_options = v.get("storage_options", {})
Expand Down
50 changes: 38 additions & 12 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import pytest

from gnatss.main import gather_files


def test_gather_files(mocker):
@pytest.mark.parametrize("proc", ["solver", "posfilter", "random"])
def test_gather_files(mocker, proc):
tt = "travel_times"
rph = "roll_pitch_heading"
glob_vals = [tt, rph]
expected_procs = {
"solver": ["sound_speed", tt, "gps_solution", "deletions"],
"posfilter": [rph],
}

# Setup get_filesystem mock
glob_res = [
"/some/path/to/1",
Expand All @@ -16,25 +27,40 @@ def glob(path):
mocker.patch("gnatss.main._get_filesystem", return_value=Filesystem)

# Setup mock configuration
item_keys = ["sound_speed", "travel_times", "gps_solution", "deletions"]
item_keys = []
if proc in expected_procs:
item_keys = expected_procs[proc]

sample_dict = {
k: {
"path": f"/some/path/to/{k}"
if k != "travel_times"
if k not in glob_vals
else "/some/glob/**/path",
"storage_options": {},
}
for k in item_keys
}
config = mocker.patch("gnatss.configs.main.Configuration")
config.solver.input_files.dict.return_value = sample_dict
if proc in list(expected_procs.keys()):
# Test for actual proc that exists
getattr(config, proc).input_files.dict.return_value = sample_dict

all_files_dict = gather_files(config, proc=proc)
# Check all_files_dict
assert isinstance(all_files_dict, dict)
assert sorted(list(all_files_dict.keys())) == sorted(item_keys)

# Test glob
for val in glob_vals:
if val in all_files_dict:
assert isinstance(all_files_dict[val], list)
assert all_files_dict[val] == glob_res
else:
# Test for random
del config.random

# Perform test
all_files_dict = gather_files(config)
# Check all_files_dict
assert isinstance(all_files_dict, dict)
assert sorted(list(all_files_dict.keys())) == sorted(item_keys)
with pytest.raises(AttributeError) as exc_info:
all_files_dict = gather_files(config, proc=proc)

# Test glob
assert isinstance(all_files_dict["travel_times"], list)
assert all_files_dict["travel_times"] == glob_res
assert exc_info.type == AttributeError
assert exc_info.value.args[0] == f"Unknown process type: {proc}"