Skip to content

Commit

Permalink
feat: swtich cpu backend to llama.cpp (#638)
Browse files Browse the repository at this point in the history
* feat: swtich Cpu backend to llama.cpp

* feat: switch cpu serving to ggml

* fix cargo.toml

* use optional dependency

* fix compliation

* update ci target
  • Loading branch information
wsxiaoys authored Oct 25, 2023
1 parent 21ec60e commit 1a4c2aa
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ jobs:
- run: bash ./ci/prepare_build_environment.sh

- name: Bulid release binary
run: cargo build --no-default-features --release --target ${{ matrix.target }}
run: cargo build --no-default-features --release --target ${{ matrix.target }} --package tabby

- name: Rename release binary
run: mv target/${{ matrix.target }}/release/tabby tabby_${{ matrix.target }}
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
## Features

## Fixes and Improvements
* Switch cpu backend to llama.cpp: https://github.com/TabbyML/tabby/pull/638
* add `server.completion_timeout` to control the code completion interface timeout: https://github.com/TabbyML/tabby/pull/637

# v0.4.0

Expand Down
2 changes: 1 addition & 1 deletion crates/llama-cpp-bindings/include/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ class TextInferenceEngine {
virtual uint32_t eos_token() const = 0;
};

std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path);
} // namespace
4 changes: 2 additions & 2 deletions crates/llama-cpp-bindings/src/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ struct BackendInitializer {
};
} // namespace

std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path) {
static BackendInitializer initializer;

llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = 1;
model_params.n_gpu_layers = use_gpu ? 1 : 0;
llama_model* model = llama_load_model_from_file(std::string(model_path).c_str(), model_params);

if (!model) {
Expand Down
5 changes: 3 additions & 2 deletions crates/llama-cpp-bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ mod ffi {

type TextInferenceEngine;

fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
fn create_engine(use_gpu: bool, model_path: &str) -> UniquePtr<TextInferenceEngine>;

fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]);
fn step(self: Pin<&mut TextInferenceEngine>) -> Result<u32>;
Expand All @@ -32,6 +32,7 @@ unsafe impl Sync for ffi::TextInferenceEngine {}
pub struct LlamaEngineOptions {
model_path: String,
tokenizer_path: String,
use_gpu: bool,
}

pub struct LlamaEngine {
Expand All @@ -42,7 +43,7 @@ pub struct LlamaEngine {

impl LlamaEngine {
pub fn create(options: LlamaEngineOptions) -> Self {
let engine = create_engine(&options.model_path);
let engine = create_engine(options.use_gpu, &options.model_path);
if engine.is_null() {
panic!("Unable to load model: {}", options.model_path);
}
Expand Down
5 changes: 2 additions & 3 deletions crates/tabby/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ version = "0.5.0-dev"
edition = "2021"

[dependencies]
ctranslate2-bindings = { path = "../ctranslate2-bindings" }
tabby-common = { path = "../tabby-common" }
tabby-scheduler = { path = "../tabby-scheduler" }
tabby-download = { path = "../tabby-download" }
Expand Down Expand Up @@ -43,9 +42,8 @@ minijinja = { version = "1.0.8", features = ["loader"] }
textdistance = "1.0.2"
regex.workspace = true
thiserror.workspace = true

[target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
ctranslate2-bindings = { path = "../ctranslate2-bindings", optional = true }

[dependencies.uuid]
version = "1.3.3"
Expand All @@ -57,6 +55,7 @@ features = [

[features]
link_shared = ["ctranslate2-bindings/link_shared"]
link_cuda_static = ["ctranslate2-bindings"]

[build-dependencies]
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
Expand Down
23 changes: 13 additions & 10 deletions crates/tabby/src/serve/engine.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::path::Path;

use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
use serde::Deserialize;
use tabby_common::path::ModelDir;
use tabby_inference::TextGeneration;
Expand Down Expand Up @@ -39,33 +38,36 @@ pub struct EngineInfo {
pub chat_template: Option<String>,
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
#[cfg(not(any(feature = "link_shared", feature = "link_cuda_static")))]
fn create_local_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,
_metadata: &Metadata,
) -> Box<dyn TextGeneration> {
create_ctranslate2_engine(args, model_dir, metadata)
create_ggml_engine(&args.device, model_dir)
}

#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
#[cfg(any(feature = "link_shared", feature = "link_cuda_static"))]
fn create_local_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,
) -> Box<dyn TextGeneration> {
if args.device != super::Device::Metal {
create_ctranslate2_engine(args, model_dir, metadata)
if args.device.use_ggml_backend() {
create_ggml_engine(&args.device, model_dir)
} else {
create_llama_engine(model_dir)
create_ctranslate2_engine(args, model_dir, metadata)
}
}

#[cfg(any(feature = "link_shared", feature = "link_cuda_static"))]
fn create_ctranslate2_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,
) -> Box<dyn TextGeneration> {
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};

let device = format!("{}", args.device);
let options = CTranslate2EngineOptionsBuilder::default()
.model_path(model_dir.ctranslate2_dir())
Expand All @@ -78,11 +80,11 @@ fn create_ctranslate2_engine(
Box::new(CTranslate2Engine::create(options))
}

#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn create_llama_engine(model_dir: &ModelDir) -> Box<dyn TextGeneration> {
fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box<dyn TextGeneration> {
let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
.model_path(model_dir.ggml_q8_0_file())
.tokenizer_path(model_dir.tokenizer_file())
.use_gpu(device.ggml_use_gpu())
.build()
.unwrap();

Expand All @@ -99,6 +101,7 @@ fn get_model_dir(model: &str) -> ModelDir {

#[derive(Deserialize)]
struct Metadata {
#[allow(dead_code)]
auto_model: String,
prompt_template: Option<String>,
chat_template: Option<String>,
Expand Down
36 changes: 24 additions & 12 deletions crates/tabby/src/serve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub enum Device {
#[strum(serialize = "cpu")]
Cpu,

#[strum(serialize = "cuda")]
#[cfg(any(feature = "link_shared", feature = "link_cuda_static"))]
Cuda,

#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
Expand All @@ -85,6 +85,28 @@ pub enum Device {
ExperimentalHttp,
}

impl Device {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn use_ggml_backend(&self) -> bool {
*self == Device::Metal || *self == Device::Cpu
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn use_ggml_backend(&self) -> bool {
*self == Device::Cpu
}

#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn ggml_use_gpu(&self) -> bool {
*self == Device::Metal
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn ggml_use_gpu(&self) -> bool {
false
}
}

#[derive(Args)]
pub struct ServeArgs {
/// Model id for `/completions` API endpoint.
Expand Down Expand Up @@ -115,16 +137,6 @@ pub struct ServeArgs {
compute_type: Option<String>,
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn should_download_ggml_files(_device: &Device) -> bool {
false
}

#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn should_download_ggml_files(device: &Device) -> bool {
*device == Device::Metal
}

pub async fn main(config: &Config, args: &ServeArgs) {
valid_args(args);

Expand Down Expand Up @@ -275,7 +287,7 @@ fn start_heartbeat(args: &ServeArgs) {
async fn download_model(model: &str, device: &Device) {
let downloader = Downloader::new(model, /* prefer_local_file= */ true);
let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", model, err,);
let download_result = if should_download_ggml_files(device) {
let download_result = if device.use_ggml_backend() {
downloader.download_ggml_files().await
} else {
downloader.download_ctranslate2_files().await
Expand Down
1 change: 0 additions & 1 deletion website/docs/models/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ We recommend using
| [TabbyML/StarCoder-7B](https://huggingface.co/TabbyML/StarCoder-7B) | [BigCode-OpenRAIL-M](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) |||
| [TabbyML/StarCoder-3B](https://huggingface.co/TabbyML/StarCoder-3B) | [BigCode-OpenRAIL-M](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) |||
| [TabbyML/StarCoder-1B](https://huggingface.co/TabbyML/StarCoder-1B) | [BigCode-OpenRAIL-M](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) |||
| [TabbyML/J-350M](https://huggingface.co/TabbyML/J-350M) | [BSD-3](https://opensource.org/license/bsd-3-clause/) |||

## Chat models (`--chat-model`)

Expand Down

0 comments on commit 1a4c2aa

Please sign in to comment.