-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CrossEncoderModule with rerank API #389
Conversation
This module is closely related to EmbeddingModule. Cross-encoder models use Q and A pairs and are trained return a relevance score for rank(). The existing rerank APIs in EmbeddingModule had to encode Q and A separately and use cosine similarity as a score. So the API is the same, but the results are supposed to be better (and slower). Cross-encoder models do not support returning embedding vectors or sentence-similarity. Support for the existing tokenization and model_info endpoints was also added. Signed-off-by: Mark Sturdevant <[email protected]>
c118db3
to
5b0989f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some q's mainly around how much ipex will apply here, configurable parameters, and truncation result testing!
|
||
|
||
@pytest.mark.parametrize("truncate_input_tokens", [-1, 99, 510, 511, 512]) | ||
def test_too_many_tokens_with_truncation_working(truncate_input_tokens, loaded_model): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
while no errors are raised, maybe there should be at least one test to make sure the truncation leads to the expected final result (mainly to make sure the logic of the _truncation_needed
function like positioning and all is tested/working as expected)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep. Needed to push the PR before I could get to that. Will add some confirming test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
error.value_check( | ||
"<NLP20896115E>", | ||
artifacts_path, | ||
ValueError(f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
value_check
with automatically through ValueError
, you do not need to pass ValueError
here. You can simple pass f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'"
instead
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. Fixed
if ipex: | ||
if autocast: # IPEX performs best with autocast using bfloat16 | ||
model = ipex.optimize( | ||
model, dtype=torch.bfloat16, weights_prepack=False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is bfloat16
supported on all devices for ipex ? or should we make the dtype
configurable somehow?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Took this out because we won't really properly test the ipex options for short-term cross-encoder needs. But FYI, various config names related to dtype/bfloat16 were rejected as confusing with other uses so for embeddings it ended up ipex + autocast is how you take advantage ipex with bloat16 speed.
self, | ||
queries: List[str], | ||
documents: List[JsonDict], | ||
top_n: Optional[int] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: why not call this as top_k
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Long story, but short version is some folks thought that had other implications and preferred to avoid it and go with top_n. This is now in our rerank API that is shared with text-embedding models and cross-encoder models so changing it would not be great.
So today I have top_n for the API we expose but it becomes top_k as the familiar parameter name in CrossEncoder functions. Sorry. Is there a better thing to do here?
def smart_batching_collate_text_only( | ||
self, batch, truncate_input_tokens: Optional[int] = 0 | ||
): | ||
texts = [[] for _ in range(len(batch[0]))] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this can be:
texts = [[]] * len(batch[0])
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
your way look better and "looks" equivalent but it breaks the data fetcher. Using range works.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
interesting. Whats "data fetcher" here ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is using torch DataLoader so callables and iterators or being used. Overkill if I wrote it from scratch, but I'm using what sentence-transformers/CrossEncoder has been using as much as possible (with our extensions as needed for truncation and token counting).
activation_fct=None, | ||
apply_softmax=False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
activation_fct=None, | |
apply_softmax=False, | |
activation_fct = None, | |
apply_softmax = False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope. Our linter/formatter insists on no spaces around keyword parameter equals which is a good thing.
The odd thing is that lint/fmt rules are the opposite when there is a type.
Fortunately tox takes care of this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh wow 🤔 thats weird
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is weird, but I find trying to make python a typed language is generally a little weird (usually not this odd)
pred_scores = torch.stack(pred_scores) | ||
elif convert_to_numpy: | ||
pred_scores = np.asarray( | ||
[score.cpu().detach().numpy() for score in pred_scores] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wondering if
score.cpu().detach().numpy()
should be score.cpu().detach().float().item()
, since numpy()
can be an array but we want a float here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, your suggestion looks better. Done
* mostly removing unnecessary code * some better clarity Signed-off-by: Mark Sturdevant <[email protected]>
* The already borrowed errors are fixed with tokenizers per thread, so there were some misleading comments about not changing params for truncation (which we do for cross-encoder truncation). Signed-off-by: Mark Sturdevant <[email protected]>
Thanks for the reviews! Forgot to mention regarding the removal of ipex, etc code... Part of that I was keeping to get MPS support as well, but I've found out that the default with CrossEncoder handles MPS and CUDA device already. |
Default is 32. Can override with embedding batch_size in config or EMBEDDING_BATCH_SIZE env var. Signed-off-by: Mark Sturdevant <[email protected]>
* Moved the truncation check to a place that can determine the proper index for the error message (with batching). * Added test to validate some results after truncation. This is with a tiny model, but works for sanity. Signed-off-by: Mark Sturdevant <[email protected]>
The part that really tests that a token is truncated was wrong. * It was backwards and passing because the scores are sorted by rank * Using the index to get scores in the order of the inputs * Now correctly xx != xy but xy == xyz (truncated z) Signed-off-by: Mark Sturdevant <[email protected]>
Signed-off-by: Mark Sturdevant <[email protected]>
45acabb
to
8fa67cc
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM - thanks for the updates!
This module is closely related to EmbeddingModule.
Cross-encoder models use Q and A pairs and are trained return a relevance score for rank(). The existing rerank APIs in EmbeddingModule had to encode Q and A separately and use cosine similarity as a score. So the API is the same, but the results are supposed to be better (and slower).
Cross-encoder models do not support returning embedding vectors or sentence-similarity.
Support for the existing tokenization and model_info endpoints was also added.