Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding a Deep Nearest Class Means Classifier model to Flair #3532

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

sheldon-roberts
Copy link
Contributor

This PR adds a DeepNCMClassifier to flair.models
My reasons for adding this model are outlined in the issue: #3531

This model requires a TrainerPlugin because it makes the prototype updates using an after_training_batch hook. Please let me know if there is a cleaner way to handle this.

Example Script:

from flair.data import Corpus
from flair.datasets import TREC_50
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import DeepNCMClassifier
from flair.trainers import ModelTrainer
from flair.trainers.plugins import DeepNCMPlugin

# load the TREC dataset
corpus: Corpus = TREC_50()

# make a transformer document embedding
document_embeddings = TransformerDocumentEmbeddings("roberta-base", fine_tune=True)

# create the classifier
classifier = DeepNCMClassifier(
    document_embeddings,
    label_dictionary=corpus.make_label_dictionary(label_type="class"),
    label_type="class",
    use_encoder=False,
    mean_update_method="condensation",
)

# initialize the trainer
trainer = ModelTrainer(classifier, corpus)

# train the model
trainer.fine_tune(
    "resources/taggers/deepncm_trec",
    plugins=[DeepNCMPlugin()],
)

@plonerma
Copy link
Collaborator

plonerma commented Aug 19, 2024

Hello @sheldon-roberts,

Thanks a lot for your contribution! This is had been buried deep in the backlog of things to implement.

I also don't see a way of how this could be implemented without a TrainerPlugin.

What do you think about implementing this as a decoder (such as the PrototypicalDecoder), such that it can be used with the default classifier? Then it could be used with all model types (i.e. span, text, etc. classification).

Additionally, what do you think about supporting the different distance functions similar to the PrototypicalDecoder?

@sheldon-roberts
Copy link
Contributor Author

Hi @plonerma, Thanks for taking a look!

What do you think about implementing this as a decoder (such as the PrototypicalDecoder), such that it can be used with the default classifier? Then it could be used with all model types (i.e. span, text, etc. classification).
Additionally, what do you think about supporting the different distance functions similar to the PrototypicalDecoder?

I really like both of these ideas! I will look into making these changes soon

@MattGPT-ai
Copy link
Contributor

Hello @sheldon-roberts,

Thanks a lot for your contribution! This is had been buried deep in the backlog of things to implement.

I also don't see a way of how this could be implemented without a TrainerPlugin.

What do you think about implementing this as a decoder (such as the PrototypicalDecoder), such that it can be used with the default classifier? Then it could be used with all model types (i.e. span, text, etc. classification).

Additionally, what do you think about supporting the different distance functions similar to the PrototypicalDecoder?

In order to avoid using a trainer plugin, could we just add a function like def after_training_epoch(): pass that gets added to the base Model class, which gets called right before or after self.dispatch("after_training_epoch", epoch=epoch) in the train_custom function?

I think this would work with this being a class, but might not work when it gets changed to a decoder.

@MattGPT-ai
Copy link
Contributor

I am currently working on converting this class to a simpler decoder. I have gotten it to work, but it requires some changes to other classes; the label tensors have to be provided to the forward passes so they can go into the decoder call. Specifically, in DefaultClassifier.forward_loss, you need to have scores = self.decoder(data_point_tensor, label_tensor). In predict, this isn't necessary because you don't need to calculate the proto updates.

Would it make sense to always pass in this in, but just have most base cases ignore the parameter? Another alternative would be to have the class set self.label_tensor before the call so it doesn't need to be an input param at all. Not sure if anyone else has a suggestion of how to design this. I will be pushing up the specific code soon, but just looking for opinions.

@MattGPT-ai MattGPT-ai force-pushed the deepncm-classifier branch 2 times, most recently from 540e00b to c92f501 Compare November 24, 2024 00:04
Add tests for DeepNCMClassifier

Remove old test

Add multi label support

Add type hints and doc strings
@MattGPT-ai
Copy link
Contributor

This has

Hello @sheldon-roberts,

Thanks a lot for your contribution! This is had been buried deep in the backlog of things to implement.

I also don't see a way of how this could be implemented without a TrainerPlugin.

What do you think about implementing this as a decoder (such as the PrototypicalDecoder), such that it can be used with the default classifier? Then it could be used with all model types (i.e. span, text, etc. classification).

Additionally, what do you think about supporting the different distance functions similar to the PrototypicalDecoder?

This has been updated to be a decoder. It's overall a lot less code and simpler, although it required some small changes to the DefaultClassifier class, and still requires a plugin. Am definitely open to any suggestion of how to better integrate this.

…ifferent model types. make small changes to DefaultClassifier forward_loss to pass label tensor when needed. update tests
@MattGPT-ai
Copy link
Contributor

Looks like tests are passing except for a couple of MyPy checks that aren't directly related to the changes in the PR, I think just files that this PR touches. Do you have any suggestions for fixing these typing problems?

@MattGPT-ai
Copy link
Contributor

Would it be better to move this class into flair/nn/decoder.py now that it is a decoder?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants