Skip to content

Commit

Permalink
[rust] Fix camembert model loading (#3418)
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 authored Aug 15, 2024
1 parent 08194c4 commit a03a324
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions extensions/tokenizers/rust/src/models/camembert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,12 +480,17 @@ pub struct CamembertModel {
impl CamembertModel {
pub fn load(vb: VarBuilder, config: &CamembertConfig) -> Result<Self> {
let (embeddings, encoder) = match (
BertEmbeddings::load(vb.pp("roberta.embeddings"), config),
BertEncoder::load(vb.pp("roberta.encoder"), config),
BertEmbeddings::load(vb.pp("embeddings"), config),
BertEncoder::load(vb.pp("encoder"), config),
) {
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
(Err(err), _) | (_, Err(err)) => {
if let (Ok(embeddings), Ok(encoder)) = (
BertEmbeddings::load(vb.pp("roberta.embeddings".to_string()), config),
BertEncoder::load(vb.pp("roberta.encoder".to_string()), config),
) {
(embeddings, encoder)
} else if let (Ok(embeddings), Ok(encoder)) = (
BertEmbeddings::load(vb.pp("deberta.embeddings".to_string()), config),
BertEncoder::load(vb.pp("deberta.encoder".to_string()), config),
) {
Expand Down

0 comments on commit a03a324

Please sign in to comment.