
Inference API #109

Open
chengchingwen opened this issue Aug 6, 2022 · 5 comments
Labels
enhancement New feature or request

Comments

@chengchingwen
Owner

As mentioned in #108, we currently don't have an inference API like the pipeline from Hugging Face Transformers. Right now you need to manually load the model/tokenizer, apply them to the input data, and convert the prediction result to the corresponding labels.
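
For reference, a Hugging Face-style call could look roughly like the sketch below. None of these names (pipeline, the task string, the returned fields) exist in Transformers.jl; they are purely hypothetical, just to illustrate the kind of API this issue is asking for:

pipe = pipeline("sentiment-analysis"; model = "bert-base-cased")  # hypothetical
pipe("this movie is great")  # => (label = "positive", score = 0.98f0)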

chengchingwen added the enhancement (New feature or request) label on Aug 6, 2022
@Broever101

What's the current way to save and load a model? I'm saving it like so:

BSON.@save bsonname bert_model wordpiece tokenizer

And I'm loading it with load_bert_pretrain(bsonname), but it throws ERROR: UndefVarError: Transformers not defined while loading the tokenizer. Moreover, the Flux docs suggest you should do cpu(model) before saving -- do you think that breaks anything?

@chengchingwen
Owner Author

Simply BSON.@save and BSON.@load. I guess the error is probably because you forgot to run using Transformers before loading. And yes, it's better to do cpu(model) before saving.
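
A minimal sketch of that save/load round trip, assuming the variables from the snippet above (the file path here is made up):

using Flux, BSON, Transformers  # Transformers must be loaded so BSON can
                                # resolve the saved types on @load

bert_model = cpu(bert_model)    # move to CPU before saving, per the Flux docs
BSON.@save "ckpt/mymodel.bson" bert_model wordpiece tokenizer

# later, in a fresh session (after `using Transformers`):
BSON.@load "ckpt/mymodel.bson" bert_model wordpiece tokenizer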

@Broever101

Broever101 commented Aug 6, 2022

> Simply BSON.@save and BSON.@load. I guess the error is probably because you forgot to run using Transformers before loading. And yes, it's better to do cpu(model) before saving.

Weird. I have all the dependencies imported in the main module, and I'm including the loading script in the module. Anyway, importing them in the REPL solved the issue -- probably a dumb mistake on my part.

Right now, I'm doing this:

struct Pipeline
    bert_model
    wordpiece
    tokenizer
    bertenc
    function Pipeline(; ckpt::AbstractString="BERT_Twitter_Epochs_1")
        bert_model, wordpiece, tokenizer = load_bert_pretrain("ckpt/$ckpt.bson")
        bert_model = todevice(bert_model)  # move to GPU if available

        bertenc = BertTextEncoder(tokenizer, wordpiece)
        Flux.testmode!(bert_model)  # disable dropout for inference
        new(bert_model, wordpiece, tokenizer, bertenc)
    end
end

function (p::Pipeline)(query::AbstractString)
    data = todevice(preprocess([[query], ["0"]]))  # preprocess is my own helper
    e = p.bert_model.embed(data.input)
    t = p.bert_model.transformers(e, data.mask)

    prediction = p.bert_model.classifier.clf(
        p.bert_model.classifier.pooler(
            t[:, 1, :]  # representation of the [CLS] token
        )
    )

    @info "Prediction: " prediction
end

I can do

>p = Pipeline()
>p("this classifier sucks")
┌ Info: Prediction:
│   prediction =
│    2×1 Matrix{Float32}:
│     -0.06848035
└     -2.7152526

I have no idea how to interpret the results (should I take the absolute value to know which one-hot category is hot?), but is this the correct approach?

@chengchingwen
Owner Author

Several points:

  1. You don't need load_bert_pretrain; you can just use BSON.@load.
  2. BertTextEncoder contains both the tokenizer and the wordpiece, so you don't need to store all of them separately.
  3. You would need Flux.onecold(prediction) to turn the logits into the index of the predicted label (see the sketch below).
  4. But the meaning of each label is missing here, so you might want to store the label names in your checkpoint file as well.
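
A minimal sketch of points 3 and 4, assuming two sentiment classes; the labels vector here is hypothetical and would come from your checkpoint:

labels = ["negative", "positive"]  # hypothetical, saved alongside the model
idx = Flux.onecold(prediction)     # index of the largest logit per column
labels[idx]                        # map indices back to label names
# or equivalently, in one call:
Flux.onecold(prediction, labels)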

@stemann

stemann commented Aug 11, 2023

Any further thoughts on an Inference API?

@ashwani-rathee and I have been discussing a framework-agnostic API - in particular for inference - that might be relevant with respect to an inference API for Transformers.jl: https://julialang.zulipchat.com/#narrow/stream/390029-image-processing/topic/DL.20based.20tools/near/383544112
