Inference API #109
Comments
What's the way to save and load a model currently? I'm saving it like so: `BSON.@save bsonname bert_model wordpiece tokenizer`, and loading it with the matching `BSON.@load`.
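For reference, a minimal sketch of the BSON.jl round trip; the file name `model.bson` is illustrative:

```julia
using BSON

# Serialize the three objects under their variable names.
BSON.@save "model.bson" bert_model wordpiece tokenizer

# Later, restore variables of the same names into the current scope.
# Note: the packages defining these types (Transformers, Flux, ...)
# must already be loaded, since @load reconstructs the original types.
BSON.@load "model.bson" bert_model wordpiece tokenizer
```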
Weird. I have all the dependencies imported in the main module and I'm including the loading script in the module. Anyways, importing them in the REPL solved the issue -- probably a dumb mistake on my part. Right now, I'm doing this:

```julia
struct Pipeline
    bert_model
    wordpiece
    tokenizer
    bertenc

    function Pipeline(; ckpt::AbstractString="BERT_Twitter_Epochs_1")
        bert_model, wordpiece, tokenizer = load_bert_pretrain("ckpt/$ckpt.bson")
        bert_model = todevice(bert_model)
        bertenc = BertTextEncoder(tokenizer, wordpiece)
        Flux.testmode!(bert_model)
        new(bert_model, wordpiece, tokenizer, bertenc)
    end
end

function (p::Pipeline)(query::AbstractString)
    data = todevice(preprocess([[query], ["0"]]))
    e = p.bert_model.embed(data.input)
    t = p.bert_model.transformers(e, data.mask)
    prediction = p.bert_model.classifier.clf(
        p.bert_model.classifier.pooler(
            t[:, 1, :]
        )
    )
    @info "Prediction: " prediction
    return prediction  # without this, the call returns `nothing` (the value of @info)
end
```

I can do:

```julia
julia> p = Pipeline();

julia> p("this classifier sucks")
┌ Info: Prediction:
│   prediction =
│    2×1 Matrix{Float32}:
│     -0.06848035
└     -2.7152526
```

I have no idea how to interpret the results (should I uhh take the absolute to know which one-hot category is hot??) but is this the correct approach?
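For what it's worth, a sketch of one way to read that output: the two entries exponentiate to roughly 0.93 and 0.07 and sum to 1, so they look like per-class log-probabilities (logsoftmax output), and `argmax`, not the absolute value, picks the predicted class. The label order below is an assumption, not something this thread confirms:

```julia
# Interpreting the 2×1 output above as per-class log-probabilities.
prediction = reshape(Float32[-0.06848035, -2.7152526], 2, 1)

probs = exp.(prediction)      # ≈ [0.934; 0.066], sums to 1
labels = ["0", "1"]           # assumed order of the one-hot categories
labels[argmax(vec(probs))]    # most likely class ("0" here)
```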
Any further thoughts on an inference API? @ashwani-rathee and I have been discussing a framework-agnostic API, in particular for inference, that might be relevant to an inference API for Transformers.jl: https://julialang.zulipchat.com/#narrow/stream/390029-image-processing/topic/DL.20based.20tools/near/383544112
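To make the idea concrete, a purely hypothetical sketch of what such a framework-agnostic entry point might look like; none of these names exist in any package:

```julia
# Hypothetical only: a minimal framework-agnostic inference surface.
# Each backend (Transformers.jl, a plain Flux model, ...) would wrap
# its model plus pre/postprocessing behind one callable interface.
abstract type AbstractPipeline end

"""
    predict(p::AbstractPipeline, input)

Run preprocessing, the model, and postprocessing, returning a
framework-independent result (e.g. a label and a score).
"""
function predict end
```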
As mentioned in #108, we currently don't have an inference API like the `pipeline` from HuggingFace Transformers. Right now you need to manually load the model/tokenizer, apply them to the input data, and convert the prediction result to the corresponding labels.
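Put together, a sketch of those manual steps using the `Pipeline` from earlier in this thread; the label list is an assumption about that particular checkpoint:

```julia
p = Pipeline()                          # 1. load the model and tokenizer
logprobs = p("this classifier sucks")   # 2. preprocess and run the model (2×1 output)
labels = ["0", "1"]                     # assumed label order for the checkpoint
labels[argmax(vec(logprobs))]           # 3. convert the prediction to its label
```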