-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding a Deep Nearest Class Means Classifier model to Flair #3532
base: master
Are you sure you want to change the base?
Adding a Deep Nearest Class Means Classifier model to Flair #3532
Conversation
Hello @sheldon-roberts, Thanks a lot for your contribution! This is had been buried deep in the backlog of things to implement. I also don't see a way of how this could be implemented without a What do you think about implementing this as a decoder (such as the Additionally, what do you think about supporting the different distance functions similar to the |
Hi @plonerma, Thanks for taking a look!
I really like both of these ideas! I will look into making these changes soon |
In order to avoid using a trainer plugin, could we just add a function like I think this would work with this being a class, but might not work when it gets changed to a decoder. |
I am currently working on converting this class to a simpler decoder. I have gotten it to work, but it requires some changes to other classes; the label tensors have to be provided to the forward passes so they can go into the decoder call. Specifically, in Would it make sense to always pass in this in, but just have most base cases ignore the parameter? Another alternative would be to have the class set |
540e00b
to
c92f501
Compare
Add tests for DeepNCMClassifier Remove old test Add multi label support Add type hints and doc strings
c92f501
to
b19e700
Compare
This has
This has been updated to be a decoder. It's overall a lot less code and simpler, although it required some small changes to the |
b19e700
to
088aac0
Compare
…ifferent model types. make small changes to DefaultClassifier forward_loss to pass label tensor when needed. update tests
088aac0
to
f94a56f
Compare
Looks like tests are passing except for a couple of MyPy checks that aren't directly related to the changes in the PR, I think just files that this PR touches. Do you have any suggestions for fixing these typing problems? |
Would it be better to move this class into |
This PR adds a
DeepNCMClassifier
to flair.modelsMy reasons for adding this model are outlined in the issue: #3531
This model requires a
TrainerPlugin
because it makes the prototype updates using anafter_training_batch
hook. Please let me know if there is a cleaner way to handle this.Example Script: