Fix ner entity merging (#596)
magdyksaleh authored Sep 6, 2024
1 parent b1578d2 commit 39b6c8e
Showing 3 changed files with 81 additions and 77 deletions.
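
Context for the diff below: the classifier emits one BIO-tagged prediction per token ("B-PER" opens a person entity, "I-PER" continues it, "O" is outside any entity), and consecutive tokens of the same type must be merged into a single entity span. The old format_ner_output re-tokenized the decoded text with a regex and recovered character positions by repeatedly decoding token-id prefixes; the new aggregate_ner_output_simple reads character offsets directly from the tokenizer encoding and averages scores across a span. A simplified, self-contained sketch of that merge rule (hypothetical sentence and labels; not the exact code from this commit, which also handles a few edge cases differently):

// Hypothetical token stream: (char_start, char_end, BIO label, score).
// "B-X" opens a span, "I-X" extends a span of the same type, "O" closes it;
// a merged span's score is the mean of its token scores. Unlike the commit's
// code, a lone "I-X" with no open span also opens one here.
fn flush(
    cur: &mut Option<(usize, usize, String, Vec<f32>)>,
    out: &mut Vec<(usize, usize, String, f32)>,
) {
    if let Some((s, e, tag, scores)) = cur.take() {
        out.push((s, e, tag, scores.iter().sum::<f32>() / scores.len() as f32));
    }
}

fn merge_bio(tokens: &[(usize, usize, &str, f32)]) -> Vec<(usize, usize, String, f32)> {
    let mut out = Vec::new();
    let mut cur: Option<(usize, usize, String, Vec<f32>)> = None;
    for &(start, end, label, score) in tokens {
        match label.split_once('-') {
            Some((bi, tag)) => {
                let continues = bi == "I" && cur.as_ref().map_or(false, |c| c.2 == tag);
                if continues {
                    let c = cur.as_mut().unwrap();
                    c.1 = end; // extend the span to this token's end offset
                    c.3.push(score); // collect the score for the mean
                } else {
                    flush(&mut cur, &mut out);
                    cur = Some((start, end, tag.to_string(), vec![score]));
                }
            }
            None => flush(&mut cur, &mut out), // "O" ends any open span
        }
    }
    flush(&mut cur, &mut out);
    out
}

fn main() {
    let text = "Sarah Connor lives in London";
    let tokens = [
        (0, 5, "B-PER", 0.99),
        (6, 12, "I-PER", 0.97),
        (13, 18, "O", 0.99),
        (19, 21, "O", 0.99),
        (22, 28, "B-LOC", 0.98),
    ];
    for (start, end, tag, score) in merge_bio(&tokens) {
        println!("{tag}: {:?} (score {score:.2})", &text[start..end]);
    }
}

Running this prints PER: "Sarah Connor" and LOC: "London", the merged per-entity output this commit moves toward.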
122 changes: 69 additions & 53 deletions router/src/infer.rs
@@ -1,19 +1,20 @@
/// Batching and inference logic
use crate::adapter::{extract_adapter_params, Adapter, BASE_MODEL_ADAPTER_ID};
use crate::batch::{ValidClassifyRequest, ValidEmbedRequest};
use crate::queue::AdapterEvent;
use crate::scheduler::AdapterScheduler;
use crate::validation::{Validation, ValidationError};
use crate::{
AdapterParameters, AlternativeToken, BatchClassifyRequest, BatchClassifyResponse,
ChatTemplateVersions, ClassifyRequest, ClassifyResponse, EmbedRequest, EmbedResponse, Entity,
Entry, HubTokenizerConfig, Message, TextMessage, Token, TokenizerConfigToken,
AdapterParameters, AlternativeToken, BatchClassifyRequest, ChatTemplateVersions,
ClassifyRequest, EmbedRequest, EmbedResponse, Entity, Entry, HubTokenizerConfig, Message,
TextMessage, Token, TokenizerConfigToken,
};
use crate::{GenerateRequest, PrefillToken};
use flume::r#async::RecvStream;
use flume::SendTimeoutError;
use futures::future::try_join_all;
use futures::stream::StreamExt;
/// Batching and inference logic
use itertools::izip;
use itertools::multizip;
use lorax_client::{
Batch, CachedBatch, ClassifyPredictionList, ClientError, Embedding, GeneratedText, Generation,
@@ -22,7 +23,6 @@ use lorax_client::{
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
use nohash_hasher::IntMap;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::{
@@ -584,7 +584,7 @@ impl Infer {
pub(crate) async fn classify(
&self,
request: ClassifyRequest,
) -> Result<ClassifyResponse, InferError> {
) -> Result<Vec<Entity>, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let _permit = self
.clone()
@@ -613,7 +613,7 @@ impl Infer {
.await?;

let valid_request = ValidClassifyRequest {
inputs,
inputs: inputs.clone(),
tokenized_inputs,
input_length: input_length as u32,
adapter: adapter.clone(),
@@ -665,16 +665,18 @@
queued: _,
id: _,
} => {
let entities = format_ner_output(predictions, self.tokenizer.clone().unwrap());
let entities = aggregate_ner_output_simple(
inputs.clone(),
predictions,
self.tokenizer.clone().unwrap(),
);
return_entities = Some(entities);
}
}
}

if let Some(return_entities) = return_entities {
Ok(ClassifyResponse {
entities: return_entities.into_iter().map(Entity::from).collect(),
})
Ok(return_entities.into_iter().map(Entity::from).collect())
} else {
let err = InferError::ClassificationFailure;
metrics::increment_counter!("lorax_request_failure", "err" => "classification_failure");
@@ -687,7 +689,7 @@
pub(crate) async fn classify_batch(
&self,
request: BatchClassifyRequest,
) -> Result<BatchClassifyResponse, InferError> {
) -> Result<Vec<Vec<Entity>>, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let _permit = self
.clone()
@@ -712,6 +714,13 @@
// MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = flume::unbounded();

let request_id_map: HashMap<u64, String> = request
.inputs
.iter()
.enumerate()
.map(|(id, input)| (id as u64, input.clone()))
.collect();

for (id, r_inputs) in request.inputs.iter().enumerate() {
let inputs = r_inputs.to_string().clone();
let (tokenized_inputs, input_length) = self
@@ -757,8 +766,12 @@
queued: _,
id,
} => {
let entities =
format_ner_output(predictions.clone(), self.tokenizer.clone().unwrap());
let request_inputs = request_id_map.get(&id.unwrap()).unwrap().clone();
let entities = aggregate_ner_output_simple(
request_inputs,
predictions.clone(),
self.tokenizer.clone().unwrap(),
);
all_entities.insert(id.unwrap(), entities);
}
_ => {
@@ -782,9 +795,7 @@
.map(|(_, entities)| entities.into_iter().map(Entity::from).collect())
.collect();

Ok(BatchClassifyResponse {
entities: sorted_entities,
})
Ok(sorted_entities)
}
}

@@ -1478,67 +1489,72 @@ impl InferError {
}
}

fn format_ner_output(
fn get_tag(token_class: &str) -> (String, String) {
let parts: Vec<&str> = token_class.split('-').collect();
if parts.len() == 2 {
(parts[0].to_string(), parts[1].to_string())
} else {
("O".to_string(), "O".to_string())
}
}

fn aggregate_ner_output_simple(
input: String,
classify_prediction_list: ClassifyPredictionList,
tokenizer: Arc<Tokenizer>,
) -> Vec<Entity> {
let input_ids =
&classify_prediction_list.input_ids[1..classify_prediction_list.input_ids.len() - 1];
let predicted_token_class =
// Encode the input
let encoded = tokenizer.encode(input.clone(), false).unwrap();

let predicted_token_classes =
&classify_prediction_list.predictions[1..classify_prediction_list.predictions.len() - 1];
let scores = &classify_prediction_list.scores[1..classify_prediction_list.scores.len() - 1];

let tokens: Vec<String> = {
let re = Regex::new(r"\b\w+\b|\S").unwrap();
re.find_iter(&tokenizer.decode(input_ids, true).unwrap())
.map(|m| m.as_str().to_string())
.collect()
};

// Initialize result and tracking variables
let mut ner_results = Vec::new();
let mut current_entity: Option<Entity> = None;
for (i, ((token, token_class), score)) in tokens
.iter()
.zip(predicted_token_class.iter())
.zip(scores.iter())
.enumerate()
let mut entity_scores = Vec::new();

for (offset, token_class, score) in
izip!(encoded.get_offsets(), predicted_token_classes, scores)
{
if token_class != "O" {
let (bi, tag) = get_tag(token_class);
if bi == "B"
|| (current_entity.is_some() && tag != current_entity.as_ref().unwrap().entity)
|| (current_entity.is_some()
&& tag != current_entity.as_ref().unwrap().entity_group)
{
if let Some(entity) = current_entity {
if let Some(entity) = current_entity.take() {
ner_results.push(entity);
entity_scores.clear();
entity_scores.push(*score);
}
current_entity = Some(Entity {
entity: tag,
entity_group: tag,
score: *score,
index: i,
word: token.to_string(),
start: tokenizer.decode(&input_ids[..i], false).unwrap().len(),
end: tokenizer.decode(&input_ids[..=i], false).unwrap().len(),
word: "".to_string(), // stub for now. set later in second pass
start: offset.0,
end: offset.1,
});
} else if bi == "I" && current_entity.is_some() {
} else if current_entity.is_some() {
entity_scores.push(*score);
let entity = current_entity.as_mut().unwrap();
entity.word += &token.replace("##", "");
entity.end = tokenizer.decode(&input_ids[..=i], false).unwrap().len();
entity.score = entity_scores.iter().sum::<f32>() / entity_scores.len() as f32;
entity.end = offset.1;
}
} else if let Some(entity) = current_entity.take() {
ner_results.push(entity);
entity_scores.clear();
entity_scores.push(*score);
}
}
if let Some(entity) = current_entity {
if let Some(entity) = current_entity.take() {
ner_results.push(entity);
}
ner_results
}

fn get_tag(token_class: &str) -> (String, String) {
let parts: Vec<&str> = token_class.split('-').collect();
if parts.len() == 2 {
(parts[0].to_string(), parts[1].to_string())
} else {
("O".to_string(), "O".to_string())
let mut new_ner_results = Vec::with_capacity(ner_results.len());
for mut entity in ner_results {
entity.word = input[entity.start..entity.end].to_string();
new_ner_results.push(entity);
}
new_ner_results
}
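
The heart of the fix above is the switch from decode-based span reconstruction to the encoding's character offsets, so an entity's surface form is simply a slice of the original input. A minimal sketch of that idea with the tokenizers crate (the tokenizer.json path is hypothetical; error handling kept short):

use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical path; any Hugging Face tokenizer.json works.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let input = "Sarah Connor lives in London";

    // `false` = no special tokens, matching the encode call in this commit.
    let encoding = tokenizer.encode(input, false)?;

    // Each token carries (start, end) offsets into `input`, so a merged
    // entity's word is just input[entity.start..entity.end], as in the
    // second pass of aggregate_ner_output_simple above.
    for (token, &(start, end)) in encoding.get_tokens().iter().zip(encoding.get_offsets()) {
        println!("{token:>12} -> {:?}", &input[start..end]);
    }
    Ok(())
}

This also explains the new request_id_map in classify_batch: batch responses arrive keyed by request id, and the original input string must be recovered per id to slice out each entity's word.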
16 changes: 2 additions & 14 deletions router/src/lib.rs
@@ -742,26 +742,15 @@ struct ClassifyRequest {
inputs: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct ClassifyResponse {
entities: Vec<Entity>,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct BatchClassifyRequest {
inputs: Vec<String>,
}

#[derive(Debug, Serialize, Deserialize)]
struct BatchClassifyResponse {
entities: Vec<Vec<Entity>>,
}

#[derive(Debug, Serialize, Deserialize)]
struct Entity {
entity: String,
entity_group: String,
score: f32,
index: usize,
word: String,
start: usize,
end: usize,
@@ -770,9 +759,8 @@ struct Entity {
impl From<EntityMessage> for Entity {
fn from(entity: EntityMessage) -> Self {
Entity {
entity: entity.entity,
entity_group: entity.entity,
score: entity.score,
index: entity.index as usize,
word: entity.word,
start: entity.start as usize,
end: entity.end as usize,
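
Net effect of the lib.rs changes: the ClassifyResponse and BatchClassifyResponse wrappers are deleted, Entity drops index, and entity is renamed entity_group, so responses now serialize as bare JSON arrays. A sketch of the resulting wire format using a local mirror of the struct (field values hypothetical):

use serde::Serialize;

// Local mirror of the slimmed-down Entity, for illustration only.
#[derive(Serialize)]
struct Entity {
    entity_group: String,
    score: f32,
    word: String,
    start: usize,
    end: usize,
}

fn main() {
    let entities = vec![Entity {
        entity_group: "PER".to_string(),
        score: 0.98,
        word: "Sarah Connor".to_string(),
        start: 0,
        end: 12,
    }];
    // Prints a bare array, with no {"entities": ...} wrapper:
    // [{"entity_group":"PER","score":0.98,"word":"Sarah Connor","start":0,"end":12}]
    println!("{}", serde_json::to_string(&entities).unwrap());
}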
20 changes: 10 additions & 10 deletions router/src/server.rs
@@ -6,15 +6,15 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{json, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig};
use crate::{
AdapterParameters, AlternativeToken, BatchClassifyRequest, BatchClassifyResponse,
BestOfSequence, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
AdapterParameters, AlternativeToken, BatchClassifyRequest, BestOfSequence,
ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
ChatCompletionStreamResponse, ChatCompletionStreamResponseChoice, ChatMessage, ClassifyRequest,
ClassifyResponse, CompatGenerateRequest, CompletionFinishReason, CompletionRequest,
CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice,
CompletionStreamResponse, Details, EmbedRequest, EmbedResponse, ErrorResponse, FinishReason,
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, LogProbs,
PrefillToken, ResponseFormat, ResponseFormatType, SimpleToken, StreamDetails, StreamResponse,
Token, TokenizeRequest, TokenizeResponse, UsageInfo, Validation,
CompatGenerateRequest, CompletionFinishReason, CompletionRequest, CompletionResponse,
CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, Details,
EmbedRequest, EmbedResponse, Entity, ErrorResponse, FinishReason, GenerateParameters,
GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, LogProbs, PrefillToken,
ResponseFormat, ResponseFormatType, SimpleToken, StreamDetails, StreamResponse, Token,
TokenizeRequest, TokenizeResponse, UsageInfo, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@@ -1469,7 +1469,7 @@ async fn embed(
async fn classify(
infer: Extension<Infer>,
Json(req): Json<ClassifyRequest>,
) -> Result<Json<ClassifyResponse>, (StatusCode, Json<ErrorResponse>)> {
) -> Result<Json<Vec<Entity>>, (StatusCode, Json<ErrorResponse>)> {
metrics::increment_counter!("lorax_request_count");
tracing::debug!("Input: {}", req.inputs);
let response = infer.classify(req).await?;
@@ -1490,7 +1490,7 @@ async fn classify(
async fn classify_batch(
infer: Extension<Infer>,
Json(req): Json<BatchClassifyRequest>,
) -> Result<Json<BatchClassifyResponse>, (StatusCode, Json<ErrorResponse>)> {
) -> Result<Json<Vec<Vec<Entity>>>, (StatusCode, Json<ErrorResponse>)> {
metrics::increment_counter!("lorax_request_count");
tracing::debug!("Inputs: {:?}", req.inputs);
let response = infer.classify_batch(req).await?;
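
With the server.rs handlers returning Json<Vec<Entity>> and Json<Vec<Vec<Entity>>> directly, a classification call now looks roughly like this (route path and payload values are assumptions based on the handler names; they are not shown in this diff):

POST /classify
{"inputs": "Sarah Connor lives in London"}

200 OK
[{"entity_group":"PER","score":0.98,"word":"Sarah Connor","start":0,"end":12},
 {"entity_group":"LOC","score":0.97,"word":"London","start":22,"end":28}]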
