-
Notifications
You must be signed in to change notification settings - Fork 331
/
main.rs
64 lines (59 loc) · 1.99 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
use std::sync::Arc;
use tokio::sync::mpsc::channel;
use mistralrs::{
Constraint, Device, DeviceMapMetadata, GGUFLoaderBuilder, GGUFSpecificConfig, MistralRs,
MistralRsBuilder, NormalRequest, Request, RequestMessage, Response, SamplingParams,
SchedulerMethod, TokenSource,
};
/// Builds the inference runner for a locally stored, quantized Mistral model.
///
/// Nothing is downloaded from the Hugging Face Hub here. The chat template is read from
/// `chat_templates/mistral.json`, and the tokenizer and model are taken from the local GGUF
/// file `mistral-7b-instruct-v0.1.Q4_K_M.gguf` in the directory `.`.
///
/// # Errors
/// Propagates any failure from selecting the device or from loading the model pipeline.
fn setup() -> anyhow::Result<Arc<MistralRs>> {
    let chat_template = Some(String::from("chat_templates/mistral.json"));
    let model_dir = String::from(".");
    let model_file = String::from("mistral-7b-instruct-v0.1.Q4_K_M.gguf");

    let gguf_config = GGUFSpecificConfig { repeat_last_n: 64 };
    let loader = GGUFLoaderBuilder::new(gguf_config, chat_template, None, model_dir, model_file)
        .build();

    // Use CUDA device 0 when it is available, as decided by `Device::cuda_if_available`.
    let device = Device::cuda_if_available(0)?;
    let pipeline = loader.load_model_from_hf(
        None,
        TokenSource::CacheToken,
        None,
        &device,
        false,
        DeviceMapMetadata::dummy(),
        None,
    )?;

    // The literal 5 is non-zero, so this conversion cannot fail.
    let scheduler = SchedulerMethod::Fixed(5.try_into().unwrap());
    let runner = MistralRsBuilder::new(pipeline, scheduler).build();
    Ok(runner)
}
fn main() -> anyhow::Result<()> {
let mistralrs = setup()?;
let (tx, mut rx) = channel(10_000);
let request = Request::Normal(NormalRequest {
messages: RequestMessage::Completion {
text: "Hello! My name is ".to_string(),
echo_prompt: false,
best_of: 1,
},
sampling_params: SamplingParams::default(),
response: tx,
return_logprobs: false,
is_streaming: false,
id: 0,
constraint: Constraint::None,
suffix: None,
adapters: None,
});
mistralrs.get_sender().blocking_send(request)?;
let response = rx.blocking_recv().unwrap();
match response {
Response::CompletionDone(c) => println!("Text: {}", c.choices[0].text),
_ => unreachable!(),
}
Ok(())
}