
CrossEncoderModule with rerank API #389

Merged: 7 commits into caikit:main on Sep 12, 2024

Conversation

markstur (Contributor)

This module is closely related to EmbeddingModule.

Cross-encoder models use Q and A pairs and are trained to return a relevance score for rank(). The existing rerank APIs in EmbeddingModule had to encode Q and A separately and use cosine similarity as a score. So the API is the same, but the results are supposed to be better (and slower).

Cross-encoder models do not support returning embedding vectors or sentence-similarity.

Support for the existing tokenization and model_info endpoints was also added.
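
To make the contrast concrete, here is a rough sketch using sentence-transformers directly (model names are illustrative examples, not what this module ships with):

    from sentence_transformers import CrossEncoder, SentenceTransformer, util

    query = "How do I reset my password?"
    passage = "Open account settings and choose 'Reset password'."

    # Bi-encoder path (what the EmbeddingModule rerank did): encode separately,
    # then score with cosine similarity.
    bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
    q_vec, p_vec = bi_encoder.encode([query, passage])
    cosine_score = util.cos_sim(q_vec, p_vec)

    # Cross-encoder path (this module): the (query, passage) pair is scored
    # together, and the model output is the relevance score itself.
    cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    relevance_score = cross_encoder.predict([(query, passage)])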


Signed-off-by: Mark Sturdevant <[email protected]>

@evaline-ju (Collaborator) left a comment


Some q's mainly around how much ipex will apply here, configurable parameters, and truncation result testing!



@pytest.mark.parametrize("truncate_input_tokens", [-1, 99, 510, 511, 512])
def test_too_many_tokens_with_truncation_working(truncate_input_tokens, loaded_model):
Collaborator:

While no errors are raised, maybe there should be at least one test to make sure the truncation leads to the expected final result (mainly to make sure the logic of the _truncation_needed function, like positioning and all, is tested and working as expected)?

markstur (Contributor Author):

Yep. Needed to push the PR before I could get to that. Will add a confirming test.

markstur (Contributor Author):

done

error.value_check(
    "<NLP20896115E>",
    artifacts_path,
    ValueError(f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'"),
Collaborator:

value_check will automatically throw a ValueError, so you do not need to pass ValueError here. You can simply pass f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'" instead.
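
In other words, the suggested form would look roughly like this (a sketch of the reviewer's suggestion, not the exact final code):

    error.value_check(
        "<NLP20896115E>",
        artifacts_path,
        f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'",
    )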

markstur (Contributor Author):

Thanks. Fixed

if ipex:
    if autocast:  # IPEX performs best with autocast using bfloat16
        model = ipex.optimize(
            model, dtype=torch.bfloat16, weights_prepack=False
Collaborator:

Is bfloat16 supported on all devices for ipex? Or should we make the dtype configurable somehow?

markstur (Contributor Author):

Took this out because we won't really test the ipex options properly for short-term cross-encoder needs. But FYI: various config names related to dtype/bfloat16 were rejected as confusing with other uses, so for embeddings it ended up that ipex + autocast is how you take advantage of ipex with bfloat16 speed.

self,
queries: List[str],
documents: List[JsonDict],
top_n: Optional[int] = None,
Collaborator:

nit: why not call this top_k?

markstur (Contributor Author):

Long story, but the short version is some folks thought that name had other implications and preferred to avoid it and go with top_n. This is now in our rerank API that is shared with text-embedding models and cross-encoder models, so changing it would not be great.

So today I have top_n for the API we expose, but it becomes top_k as the familiar parameter name in the CrossEncoder functions. Sorry. Is there a better thing to do here?
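
A rough sketch of that mapping (function names and bodies here are illustrative placeholders, not the module's code):

    from typing import List, Optional

    def run_rerank_query(query: str, documents: List[dict], top_n: Optional[int] = None):
        """Public rerank API keeps the shared top_n name."""
        # Internally the CrossEncoder-style helper expects the more familiar
        # top_k, so top_n is simply passed through under that name.
        return _rank(query=query, documents=documents, top_k=top_n)

    def _rank(query: str, documents: List[dict], top_k: Optional[int] = None):
        # Placeholder for the CrossEncoder ranking call.
        scored = [{"index": i, "document": d} for i, d in enumerate(documents)]
        return scored[:top_k] if top_k else scored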

def smart_batching_collate_text_only(
    self, batch, truncate_input_tokens: Optional[int] = 0
):
    texts = [[] for _ in range(len(batch[0]))]
Collaborator:

nit: this can be:

texts = [[]] * len(batch[0]) 

?

markstur (Contributor Author):

Your way looks better and "looks" equivalent, but it breaks the data fetcher. Using range works.
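
For what it's worth, the two spellings are not actually equivalent in Python: [[]] * n creates n references to one shared inner list, while the comprehension creates n independent lists, so appends behave very differently:

    n = 3

    aliased = [[]] * n                    # n references to the SAME inner list
    aliased[0].append("x")
    print(aliased)                        # [['x'], ['x'], ['x']]

    independent = [[] for _ in range(n)]  # n distinct inner lists
    independent[0].append("x")
    print(independent)                    # [['x'], [], []]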

Collaborator:

Interesting. What's the "data fetcher" here?

markstur (Contributor Author):

This is using torch DataLoader, so callables and iterators are being used. Overkill if I wrote it from scratch, but I'm using what sentence-transformers/CrossEncoder has been using as much as possible (with our extensions as needed for truncation and token counting).
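
A rough sketch of that pattern, mirroring how sentence-transformers' CrossEncoder wires a DataLoader to a collate function (the cross_encoder object, sample data, and batch size here are illustrative):

    from torch.utils.data import DataLoader

    # `cross_encoder` stands for a loaded cross-encoder wrapper whose collate
    # method tokenizes (query, passage) pairs, truncates, and counts tokens.
    pairs = [("what is caikit?", "caikit is an AI toolkit"),
             ("what is caikit?", "bananas are yellow")]

    loader = DataLoader(
        pairs,
        batch_size=32,
        collate_fn=cross_encoder.smart_batching_collate_text_only,
        shuffle=False,
    )

    for features in loader:
        # each `features` dict is one tokenized batch, ready for the underlying
        # sequence-classification model's forward pass to produce relevance logits
        logits = cross_encoder.model(**features).logits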

Comment on lines +682 to +683
activation_fct=None,
apply_softmax=False,
Collaborator:

Suggested change:

    activation_fct = None,
    apply_softmax = False,

markstur (Contributor Author):

Nope. Our linter/formatter insists on no spaces around the keyword-parameter equals sign, which is a good thing.

The odd thing is that the lint/fmt rules are the opposite when there is a type annotation.

Fortunately tox takes care of this.
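
That convention comes from PEP 8, which the usual Python formatters follow: no spaces around = for a plain keyword default, but spaces when the parameter also carries a type annotation. Roughly:

    from typing import Callable, Optional

    # No annotations: no spaces around "=".
    def predict(activation_fct=None, apply_softmax=False):
        ...

    # With annotations: spaces around "=" are expected.
    def predict_typed(
        activation_fct: Optional[Callable] = None,
        apply_softmax: bool = False,
    ):
        ...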

Collaborator:

Oh wow 🤔 that's weird

markstur (Contributor Author):

It is weird, but I find trying to make Python a typed language is generally a little weird (usually not this odd).

    pred_scores = torch.stack(pred_scores)
elif convert_to_numpy:
    pred_scores = np.asarray(
        [score.cpu().detach().numpy() for score in pred_scores]
Collaborator:

Wondering if score.cpu().detach().numpy() should be score.cpu().detach().float().item(), since numpy() can be an array but we want a float here?

markstur (Contributor Author):

Yes, your suggestion looks better. Done
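
For reference, the difference the comment is pointing at (a standalone illustration, not the module's code):

    import torch

    score = torch.tensor(0.73)                        # a 0-dim score tensor
    as_array = score.cpu().detach().numpy()           # numpy ndarray (0-dim), not a float
    as_float = score.cpu().detach().float().item()    # plain Python float

    print(type(as_array).__name__, type(as_float).__name__)  # ndarray float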

* mostly removing unnecessary code
* some better clarity

Signed-off-by: Mark Sturdevant <[email protected]>
* The "Already borrowed" errors are fixed with tokenizers per thread,
  so there were some misleading comments about not changing params
  for truncation (which we do for cross-encoder truncation).

Signed-off-by: Mark Sturdevant <[email protected]>
@markstur (Contributor Author):

Thanks for the reviews!

Forgot to mention regarding the removal of the ipex, etc. code... Part of that I was keeping to get MPS support as well, but I've found out that the default with CrossEncoder handles MPS and CUDA devices already.

Batch size default is 32.
Can be overridden with the embedding batch_size in config or the EMBEDDING_BATCH_SIZE env var.

Signed-off-by: Mark Sturdevant <[email protected]>
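
A hypothetical sketch of that override (the config key and env var names come from the note above; the precedence order and helper name are assumptions):

    import os

    DEFAULT_BATCH_SIZE = 32

    def resolve_batch_size(embedding_config: dict) -> int:
        # Env var wins if set, otherwise the embedding config value, otherwise 32.
        value = os.environ.get("EMBEDDING_BATCH_SIZE") or embedding_config.get("batch_size")
        return int(value) if value else DEFAULT_BATCH_SIZE
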
* Moved the truncation check to a place that can determine
  the proper index for the error message (with batching).

* Added test to validate some results after truncation.
  This is with a tiny model, but works for sanity.

Signed-off-by: Mark Sturdevant <[email protected]>
The part that really tests that a token is truncated was wrong.

* It was backwards and passing because the scores are sorted by rank
* Using the index to get scores in the order of the inputs
* Now correctly xx != xy but xy == xyz (truncated z)

Signed-off-by: Mark Sturdevant <[email protected]>
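
For context, a small made-up illustration of the ordering point in that note: rerank results come back sorted by score, so the test first re-orders them by input index before comparing against the inputs.

    # Hypothetical ranked output for inputs ["xx", "xy", "xyz"]; values are made up.
    ranked = [{"index": 1, "score": 0.91}, {"index": 2, "score": 0.91}, {"index": 0, "score": 0.40}]

    scores_by_input = [r["score"] for r in sorted(ranked, key=lambda r: r["index"])]

    assert scores_by_input[0] != scores_by_input[1]  # "xx" vs "xy" differ
    assert scores_by_input[1] == scores_by_input[2]  # "xy" vs "xyz" match once "z" is truncated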

@evaline-ju (Collaborator) left a comment


LGTM - thanks for the updates!

@evaline-ju merged commit 1695c3b into caikit:main on Sep 12, 2024
5 checks passed