-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added a new Truss example for the Databricks DBRX-Instruct model. Cre…
…ated the directory structure, populated the config.yaml file, implemented the model loading and prediction code, and added a README. Also ran validation checks and unit tests.
- Loading branch information
Droid
committed
Apr 17, 2024
1 parent
9510193
commit 88f33e4
Showing
4 changed files
with
66 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Databricks DBRX Instruct Truss | ||
|
||
This Truss packages the DBRX-Instruct model from Databricks. DBRX-Instruct is an instruction-following language model that can be used for various language tasks. | ||
|
||
## Deploying | ||
|
||
To deploy this model using Truss, follow these steps: | ||
|
||
1. Clone this repo | ||
2. Set up a Baseten account and install the Truss CLI | ||
3. Run `truss deploy` to deploy the model on Baseten | ||
|
||
## Model Overview | ||
|
||
// TODO: Add a brief overview of the DBRX-Instruct model and its key capabilities | ||
|
||
## API Documentation | ||
|
||
// TODO: Document the key API endpoints, request parameters, and response format | ||
|
||
## Example Usage | ||
|
||
// TODO: Provide example code snippets demonstrating how to use the deployed model via its API |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Empty file |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Empty file |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import logging | ||
|
||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Model: | ||
def __init__(self, model_name="databricks/dbrx-instruct") -> None: | ||
self.model_name = model_name | ||
self.model = None | ||
self.tokenizer = None | ||
|
||
def load(self): | ||
try: | ||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) | ||
self.model = AutoModelForCausalLM.from_pretrained(self.model_name) | ||
except Exception as e: | ||
logger.error(f"Failed to load model {self.model_name}: {e}") | ||
raise | ||
|
||
def preprocess(self, request: dict) -> dict: | ||
prompt = request.get("prompt", "") | ||
return {"input_ids": self.tokenizer.encode(prompt, return_tensors="pt")} | ||
|
||
def postprocess(self, output) -> dict: | ||
return { | ||
"generated_text": self.tokenizer.decode(output[0], skip_special_tokens=True) | ||
} | ||
|
||
def predict(self, request: dict) -> dict: | ||
try: | ||
processed_input = self.preprocess(request) | ||
output = self.model.generate(**processed_input) | ||
return self.postprocess(output) | ||
except Exception as e: | ||
logger.error(f"Prediction failed: {e}") | ||
raise | ||
|
||
|
||
# Empty file |