Extend Clojure BERT example #15023
Conversation
This provides an entry point for folks working on this example in their REPL rather than the command line.
@mxnet-label-bot add [Clojure, pr-work-in-progress]
Thanks @daveliepmann - looks great so far!

> is there a way to pass the fine-tuned model directly to the infer API, rather than creating a factory over a saved checkpoint?

No, there currently isn't a way to do that. It's a good idea to investigate :)

> is my interpretation of results correct? I included some individual samples that surprised me.

We don't have a validation set, only a training set, which is going to affect the fine-tuning. We are also only running it for 3 epochs, so we only have a training accuracy of about 0.70.

In general, I think it's a great addition and I am happy to see it come along :)
{
  "data": {
    "text/plain": [
      "[0.2633881 0.7366119]"
what's this the output of?
`[0.2633881 0.7366119]` is the output of our sample sentence-pair equivalence prediction:

(predict-equivalence fine-tuned-predictor
                     "The company cut spending to compensate for weak sales ."
                     "In response to poor sales results, the company cut spending .")
I'm not sure why the result appears before its expression in the .ipynb file, but on my machine it displays this pair correctly as "In [22]" followed by "Out [22]".
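For readers following along, here is a minimal plain-Clojure sketch of how to read that two-element output. The helper name is ours, not part of the example, and the index meanings are an assumption based on the sample above, where `[0.2633881 0.7366119]` meant "equivalent":

```clojure
;; Hypothetical helper: turn the predictor's two-element softmax output
;; into a label. Assumes index 0 is the probability the pair is NOT
;; equivalent and index 1 is the probability it IS equivalent.
(defn interpret-equivalence
  [[p-not-equivalent p-equivalent]]
  (if (> p-equivalent p-not-equivalent) :equivalent :not-equivalent))

(interpret-equivalence [0.2633881 0.7366119])
;; => :equivalent
```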
;; "69792"
;; "Cisco pared spending to compensate for sluggish sales ."
;; "In response to sluggish sales , Cisco pared spending ."]
(predict-equivalence fine-tuned-predictor
did you want to add a test for this?
Just pushed one. Thanks for the idea—my PR broke the existing test, but I guess tests for examples aren't part of the CI checks.
cc @Chouffe @hellonico (can you help take a look as well?)
@daveliepmann can you rebase your code? I think that should hopefully fix the CI failures. Also, I think there's a way to perform prediction without saving the model checkpoint, but it's not well documented or straightforward AFAIK. It'll be good to have this (even if it's not typically used).
Will take a look this week @kedarbellare. Thanks for your contribution @daveliepmann :)
contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj
:aux-params (m/aux-params bert-base)
:optimizer (optimizer/adam {:learning-rate 5e-6 :epsilon 1e-9})
:batch-end-callback (callback/speedometer batch-size 1)})})]
(m/save-checkpoint fitted-model {:prefix fine-tuned-prefix :epoch num-epoch})
Why do we save the model to disk now? Could we pass in a parameter to the function to do it? This function seems to do too many things.
We have to save the model to disk because saving a checkpoint and loading it back in is currently the only way to get a prediction out of it.
With the infer API, I suppose? Maybe we could change this at some point @kedarbellare?
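To make the round trip under discussion concrete, here is a sketch using names from the example. The save side appears in the PR itself; the load side is an assumption (the Clojure package's module namespace exposes a checkpoint-loading function, but the exact name and option map should be checked against the API docs):

```clojure
;; Save the fine-tuned module to disk (as in the example)...
(m/save-checkpoint fitted-model {:prefix fine-tuned-prefix :epoch num-epoch})

;; ...then, later, load it back to obtain something we can predict with.
;; NOTE: sketch only; verify the exact function and options in the
;; clojure-mxnet module namespace before relying on this.
(def reloaded-model
  (m/load-checkpoint {:prefix fine-tuned-prefix :epoch num-epoch}))
```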
contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj
[{:name "data0" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}
 {:name "data1" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}
 {:name "data2" :shape [1] :dtype dtype/FLOAT32 :layout layout/N}])
{:epoch 3}))
Why do we need this hardcoded epoch number here? Can't we just use `num-epoch`?
As I recall, we have to hard-code the epoch because otherwise we don't know which saved model to load from disk.
Responding to edit: `num-epoch` isn't in scope in the comment. I decided against defining it globally in order to parameterize a short REPL exploration.

Another reason not to `def` the `3` here is that `num-epoch` is a value meant to be passed in from the command line, and the rich comment code is a parallel of invoking from the command line with that argument. So at a minimum we would need a new name.
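For context on why the epoch number leaks into the loading code at all: MXNet checkpoints encode the epoch in the parameter filename (epoch 3 of prefix `p` is saved as `p-0003.params`), so whatever loads the checkpoint must be told which epoch was saved. A plain-Clojure sketch of that naming convention; the helper name and sample prefix are ours, for illustration only:

```clojure
;; Hypothetical helper mirroring MXNet's checkpoint filename convention:
;; parameters for epoch N are written to "<prefix>-<N padded to 4 digits>.params".
(defn checkpoint-params-file [prefix epoch]
  (format "%s-%04d.params" prefix epoch))

(checkpoint-params-file "fine-tuned-bert" 3)
;; => "fine-tuned-bert-0003.params"
```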
contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj
contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj
Thanks a lot @daveliepmann for making the BERT example nicer! I left some comments.
Underlying fn was refactored
thanks @daveliepmann!! 💯
Description
This PR extends the BERT sentence pair example in Clojure to include testing of the fine-tuned model on individual sentence pair samples. I discussed this last week with @gigasquid on Slack.
All the important changes are in the rich comment at the bottom of `bert_sentence_classification`, and are intended to be explored with a REPL. I plan to copy the REPL-driven example to the iPython notebook example after I'm sure my approach is correct, but before merging.

I think my approach is correct but would like to double-check the following before merging:

- Is there a way to pass the fine-tuned model directly to the `infer` API, rather than creating a factory over a saved checkpoint?
- Is my interpretation of results correct? I included some individual samples that surprised me.

Checklist
Essentials

Please feel free to remove inapplicable items for your PR.

- [ ] The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
- [ ] All changes have test coverage
- [ ] README.md is updated if applicable
- [ ] Inline comments are added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable

Changes
Comments