-
Notifications
You must be signed in to change notification settings - Fork 147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add model and adapter headings #220
Conversation
router/src/server.rs
Outdated
@@ -397,6 +404,15 @@ async fn generate( | |||
time_per_token.as_millis().to_string().parse().unwrap(), | |||
); | |||
|
|||
headers.insert( | |||
"x-predibase-model-id", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Might be a bit strange to hard-code predibase-specific things into an open-source library.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree. Can we just return the base model name / size, and the adapter name passed in request?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure can just make it model-id and adapter-id
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Convection is still to use X for custom headers so would something like x-lorax-model=pb://model-name
make sense and cater for HF models too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would suffice for immediate needs, but how would this work for HF hosted models not loaded from PB?
This is agnostic to model source but we can also make sure we pass a model source header so that you can infer the pb internals |
router/src/server.rs
Outdated
@@ -286,6 +292,7 @@ async fn generate( | |||
} | |||
|
|||
let details = req.0.parameters.details || req.0.parameters.decoder_input_details; | |||
let adapter_id = req.0.parameters.adapter_id.clone(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that following #212, adapter ID can come from the merged_parameters
as well as the adapter_id
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point
router/src/server.rs
Outdated
@@ -31,6 +33,10 @@ use tracing::{info_span, instrument, Instrument}; | |||
use utoipa::OpenApi; | |||
use utoipa_swagger_ui::SwaggerUi; | |||
|
|||
lazy_static! { | |||
static ref MODEL_ID: Mutex<String> = Mutex::new("".to_string()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does this do?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It smells like we're setting a global lock on the current model ID, which would break concurrency?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It allows you to create a global var that we then update here with the model name. Needed so it is not read only.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, but why do we need to lock every time we read it if it's only initialized once at the beginning and then read-only? Feels like we're taking a concurrency hit for no reason.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tgaddair I think we would need to lock a global var but I could be wrong here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually looks like I can just use this https://docs.rs/once_cell/latest/once_cell/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like exactly what we want! Nice find.
router/src/infer.rs
Outdated
adapter_ids: vec![adapter_id.clone().unwrap()], | ||
..Default::default() | ||
}); | ||
let (adapter_id, adapter_source, adapter_parameters) = extract_adapter_params( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we may want to just discard adapter_id at this point, as it should b contained in the adapter parameters.
So I would imagine this function just returning adapter_source
and adapter_parameters
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh yeah nice!
No description provided.