Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

N-dimensional interpolation #235

Merged
merged 10 commits into from
Jun 4, 2024
Merged

N-dimensional interpolation #235

merged 10 commits into from
Jun 4, 2024

Conversation

kylecarow
Copy link
Collaborator

I've plumbed up the 0/1/2/3/N-dimensional interpolation code that I developed for FASTSim to the routee-compass-powertrain crate's ModelType::Interpolate. This is capable of taking any dimension data, with Interpolator enum variants that allow hardcoded interpolation for performance benefits (there are benchmarks in FASTSim 3 for interpolation). It supports multiple interpolation 'strategies' for 1D interpolation, but only linear for other dimensionalities.

There are more improvements to be made, as documented in #135, but this passes all tests in Rust and the single test in Python. Notably, the 2D speed/grade stuff in InterpolationSpeedGradeModel::new() is still hardcoded, so more work and thought is required to allow more dimensions and proper mapping of dimensions (and any unit considerations).

See the tests in interp.rs for examples on usage. Be sure to call .validate() after instantiating to run checks on the supplied data.

@kylecarow kylecarow added the enhancement New feature or request label May 21, 2024
@kylecarow kylecarow added this to the PyCon 2024 milestone May 21, 2024
Copy link
Collaborator

@nreinicke nreinicke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is looking great, nice work on making the interpolation so flexible. I left a couple small comments about converting a few of the unwraps into errors.

Copy link
Collaborator

@robfitzgerald robfitzgerald left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @kylecarow, excited to see this open new doors for our powertrain models. The code is great, nicely structured. I have a few comments but have only done a quick proofread:
- following on nicks suggestion, remove any calls to .unwrap() that aren't in a test suite
- Code was copied from a fastsim pull request, we should attribute it via a code comment at the top of the file

@kylecarow
Copy link
Collaborator Author

Done with converting unwraps to results! Also switched to binary search (short circuiting if it's the point is the last element of the grid), for convenience and because it should scale better.

Finished just before my flight boarded, lol.

Copy link
Collaborator

@nreinicke nreinicke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making those fixes! I've tested this out and it seems to be working as intended.

Copy link
Collaborator

@robfitzgerald robfitzgerald left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more thing I noticed that I think we want to change before we roll this in (correct me if I'm wrong, @nreinicke )

"Could not get last x-value; are x-values empty?".to_string(),
)
})?,
interp.y[0],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey guys one least question, sorry for being a bottleneck here. Using the index operator here will cause a panic if the response is not as expected. Shouldn't we be using .get instead?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's a good point. Ideally we would never get to this point if the interpolation grid was empty but it's probably good practice to use get instead (unless that has a performance cost associated with it). Perhaps it would be good enough to just add an additional check into the validate function that checks to make sure the grid data always has at least two points and give an instructive error at that point.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on some light searching it looks like there isn't really a performance difference between the direct access and the get method. So, maybe best practice would be to add a check in the validate function to make sure grid is not empty and also switch these direct access into gets?

…n methods, comments to explain where result indexing is unnecessary
@kylecarow
Copy link
Collaborator Author

kylecarow commented Jun 3, 2024

@robfitzgerald @nreinicke I've added additional checks to validation to prevent panics on bad inputs, as well as corresponding tests. I think it would be kind of impossible/impractical to code out every use of [] for indexing, things are indexed quite a lot - I have replaced it where possible, and left comments in-code where direct indexing is okay because the data is checked by validation methods. Most of the remaining indexing should be outputs from find_nearest_index, so if there are panics with indexing in the future, it means there is likely a bug in that function, which should already be pretty robust.

Replacing [] with gets would be pretty onerous in spots like this:

// interpolate in the x-direction
let c00 = self.f_xyz[x_l][y_l][z_l] * (1.0 - x_diff) + self.f_xyz[x_u][y_l][z_l] * x_diff;
let c01 = self.f_xyz[x_l][y_l][z_u] * (1.0 - x_diff) + self.f_xyz[x_u][y_l][z_u] * x_diff;
let c10 = self.f_xyz[x_l][y_u][z_l] * (1.0 - x_diff) + self.f_xyz[x_u][y_u][z_l] * x_diff;
let c11 = self.f_xyz[x_l][y_u][z_u] * (1.0 - x_diff) + self.f_xyz[x_u][y_u][z_u] * x_diff;

@nreinicke
Copy link
Collaborator

Nice, thanks for adding the additional checks in. I'm on board with your assessment about not modifying every index as long as we have the appropriate validation check at load time. I'll leave the floor open to @robfitzgerald for any other comments but overall this looks good to me.

@robfitzgerald
Copy link
Collaborator

@kylecarow i'm sorry i gotta be a stickler on this since this is in the core lib, so i would prefer we don't pass these potential panics onto users if they do something unexpected. but i think this could be a quick fix, it's a good opportunity for writing helper functions that can encapsulate the error handling. that should split the difference for us wrt conciseness and performance. i may have this wrong but here's a guess on what those signatures would look like:

fn interp1(f: Vec<f64>, x_diff: f64, l: usize, u: usize) -> Result<f64, String> { todo!() }
fn interp2(f: Vec<Vec<f64>, x_diff: f64, l: (usize, usize), u: (usize, usize)) -> Result<f64, String> { todo!() }
fn interp3(f: Vec<Vec<Vec<f64>>>, x_diff: f64, l: (usize, usize, usize), u: (usize, usize, usize)) -> Result<f64, String> { todo!() }
fn interpN(f: ArrayD<f64>, x_diff: f64, l: Vec<usize>, u: Vec<usize>) -> Result<f64, String> { todo!() }
let c00 = interp3(self.f_xyz, x_diff, (x_l, y_l, z_l), (x_u, y_l, z_l))?; 

or if you don't like that, i could also imagine just array lookup helpers like this:

fn arr3(f: Vec<Vec<Vec<f64>>>, x: usize, y: usize, z: usize) -> Result<f64, String> { todo!() }

that could be used like this:

let c00 = arr3(self.f_xyz, x_l, y_l, z_l)? * (1.0 - x_diff) + arr3(self.f_xyz(x_u, y_l, z_l)? * x_diff; 

or anything else you think could make this not be super annoying to implement!

@robfitzgerald robfitzgerald self-requested a review June 3, 2024 20:30
Copy link
Collaborator

@robfitzgerald robfitzgerald left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack, i wish i wasn't on leave, i just can't devote my attention to this the way i would like to, which makes me an evil online PR villain. realizing that i was missing the point here about the fact that validate is called every time. while that does sound expensive i think we are in good shape for now. sorry for being a nuisance!

@nreinicke
Copy link
Collaborator

No, this is good discussion, you're definitely not being a nuisance! I think the concept is that the validate method gets called just once when the interpolator is getting built which gives us our guarantees about the index bounds.

@robfitzgerald
Copy link
Collaborator

I think the concept is that the validate method gets called just once when the interpolator is getting built which gives us our guarantees about the index bounds.

Well, in that case, I would love it if we had a Interpolator::new method that calls validate internally. This gets out of the OO-ish framing of "there's a method you need to know gets called first to use this." But also recognizing that the new method would need to take an ArrayD and then do the work to specialize that down to 1/2/3D arrays when appropriate in the new method. Could we make an issue to come back to that?

@kylecarow
Copy link
Collaborator Author

validate is called every time. while that does sound expensive

@robfitzgerald
No worries, here's a description of the data validation that Nick and I had discussed in-person at PyCon:

There are two validation methods:

  • The first method validates the data that does not change on calls of the interpolator, i.e. the grid coordinates and corresponding values.
    • Example: checking the coordinates are monotonically increasing, that data isn't empty, etc.
    • This only runs when the interpolator is created, e.g. with InterpolationSpeedGradeModel::new(), I imagine only once
  • The second method validates data supplied in the actual .interpolate call, i.e. the interpolant point.

Does that help clear up the data validation process? If this still seems expensive I'm open to suggestions for improvements :)

@kylecarow
Copy link
Collaborator Author

kylecarow commented Jun 3, 2024

It looks like we're all in a comment data race!

@robfitzgerald I had wanted an Interpolator::new method that calls validate(), but struggled with it due to the way the interpolator variants need to be instantiated, it would be hard to create an function signature that is clear, concise, and allows passage of the data needed for any variant. Your ArrayD approach is an interesting way around that!

A thought that just occurred to me is to have a new method for the variant structs, rather than the Interpolator enum itself. For example, creating a 2D interpolator would change like this:

- let interp = Interpolator::Interp2D(Interp2D {
-     x: vec![0.05, 0.10, 0.15],
-     y: vec![0.10, 0.20, 0.30],
-     f_xy: vec![vec![0., 1., 2.], vec![3., 4., 5.], vec![6., 7., 8.]],
- });
- interp.validate()?;
+ let interp = Interpolator::Interp2D(Interp2D::new(
+     vec![0.05, 0.10, 0.15],
+     vec![0.10, 0.20, 0.30],
+     vec![vec![0., 1., 2.], vec![3., 4., 5.], vec![6., 7., 8.]],
+ )?);

We could then restrict users from creating the variant structs directly (i.e. via writing Interp2D { ... } instead of calling Interp2D::new(...)) by adding a phantom non-pub field to the variant structs, as described in number 2 here: https://stackoverflow.com/a/77683321/11278044

If we make it impossible to create an Interpolator without Interpolator::validate() being called, there are some ok_or_else calls that could be changed back to unwraps (they should never panic), but I don't know if we would want to do that, or even if we care about doing that.

@robfitzgerald
Copy link
Collaborator

robfitzgerald commented Jun 4, 2024

nice work, @kylecarow. while using PhantomData feels a little spooky 👻 i think it seems it's the smoothest way to hack these default enum constructors to be private and enforce our validation at construction time. nice work! and we can always add the more general Interpolator::new method which takes an arbitrary ArrayD and specializes down to 1/2/3D at a later time, and it would just use the constructors you wrote. thanks!

addendum

i was trying to think why we haven't used this solution before, given that we rely heavily on patterns using rust enums. i realized that, when an enum requires a custom constructor, i have created a matching enum that has the constructor assets which then builds the "runtime" enum at some initialization phase. so, for example, Runtime::A builds a Vec<Vec> but we want to build it from a file, which may fail, so, Config::A has a String and builds a Runtime via Config::build:

pub enum Config {
  A { file: String },
  B
}

impl Config {
  pub fn build(&self) -> Result<Runtime, String> {
    match self {
      A { file } => Ok(Runtime { table: read_table(file)! }),
      B => Ok(Runtime::B)
    }
  }
}

pub enum Runtime {
  A { table: Vec<Vec<f64>>},
  B
}

this doesn't prevent users from directly building Runtimes, but it's a pretty clear separation of concerns, representing the changes in state due to initialization via the type system, at least removing ambiguity about whether the runtime is ever "valid" has ever been validated.

another addendum

i guess this is were i still part ways from this implementation. the user could still directly build a runtime, and, that's up to them. and that's why i would still never assume i can index a vector directly. but i think we can save that kind of nitpickery for a later date, i still think this is good to go!

@kylecarow
Copy link
Collaborator Author

kylecarow commented Jun 4, 2024

Just pushed a minor commit to remove a superfluous case in output handling from find_nearest_index and a very minor formatting tweak. It should be good to merge whenever, so long as you both still approve.

And the notes above about the need to rework some code higher up in the call chain to actually be able to use features other than speed and grade still applies.

edit:

I'll also note that Nick fielded the idea of some sort of implementation for 'clipping' within interpolation itself, which we currently do here:

let (min_speed, max_speed, min_grade, max_grade) = match &self.interpolator {
interp::Interpolator::Interp2D(interp) => (
*interp.x.first().ok_or_else(|| {
TraversalModelError::PredictionModel(
"Could not get first x-value; are x-values empty?".to_string(),
)
})?,
*interp.x.last().ok_or_else(|| {
TraversalModelError::PredictionModel(
"Could not get last x-value; are x-values empty?".to_string(),
)
})?,
*interp.y.first().ok_or_else(|| {
TraversalModelError::PredictionModel(
"Could not get first y-value; are y-values empty?".to_string(),
)
})?,
*interp.y.last().ok_or_else(|| {
TraversalModelError::PredictionModel(
"Could not get last y-value; are y-values empty?".to_string(),
)
})?,
),
_ => {
return Err(TraversalModelError::PredictionModel(
"Only 2-D interpolators are currently supported".to_string(),
))
}
};
let speed_value = speed_value.max(min_speed).min(max_speed);
let grade_value = grade_value.max(min_grade).min(max_grade);

@nreinicke nreinicke merged commit 6ec188b into main Jun 4, 2024
5 checks passed
@nreinicke nreinicke deleted the kjc/interpolation branch June 4, 2024 18:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants