diff --git a/vaporetto/src/sentence.rs b/vaporetto/src/sentence.rs index 2e582d48..ccf961ac 100644 --- a/vaporetto/src/sentence.rs +++ b/vaporetto/src/sentence.rs @@ -283,16 +283,16 @@ impl Sentence { } // POS tag (false, '/') => { - if chars.is_empty() { + if chars.is_empty() || prev_boundary { return Err(VaporettoError::invalid_argument( "tokenized_text", - "must not start with a slash", + "a slash must follow a character", )); } - if prev_boundary { + if tag_str_tmp.is_some() { return Err(VaporettoError::invalid_argument( "tokenized_text", - "a slash must follow a character", + "invalid slash found", )); } tag_str_tmp.replace("".to_string()); @@ -377,7 +377,7 @@ impl Sentence { if tag_str.is_some() { return Err(VaporettoError::invalid_argument( "labeled_text", - "POS tag must be annotated to a token".to_string(), + "POS tag must be annotated to a token", )); } tags.push(None); @@ -390,7 +390,7 @@ impl Sentence { if !fixed_token && tag_str.is_some() { return Err(VaporettoError::invalid_argument( "labeled_text", - "POS tag must be annotated to a token".to_string(), + "POS tag must be annotated to a token", )); } tags.push(tag_str.take().map(Arc::new)); @@ -403,7 +403,7 @@ impl Sentence { if tag_str.is_some() { return Err(VaporettoError::invalid_argument( "labeled_text", - "POS tag must be annotated to a token".to_string(), + "POS tag must be annotated to a token", )); } tags.push(None); @@ -412,6 +412,12 @@ impl Sentence { } // POS tag '/' => { + if tag_str.is_some() { + return Err(VaporettoError::invalid_argument( + "labeled_text", + "invalid slash found", + )); + } tag_str.replace("".to_string()); } _ => { @@ -430,7 +436,7 @@ impl Sentence { if chars.len() != boundaries.len() + 1 { return Err(VaporettoError::invalid_argument( "labeled_text", - "invalid annotation".to_string(), + "invalid annotation", )); } @@ -750,11 +756,9 @@ impl Sentence { for (i, (c, b)) in chars_iter.zip(&self.boundaries).enumerate() { match b { BoundaryType::WordBoundary => { - if !self.tags.is_empty() { - if let Some(tag) = self.tags.get(i).and_then(|x| x.as_ref()) { - buffer.push('/'); - buffer.push_str(tag); - } + if let Some(tag) = self.tags.get(i).and_then(|x| x.as_ref()) { + buffer.push('/'); + buffer.push_str(tag); } buffer.push(' '); } @@ -812,51 +816,34 @@ impl Sentence { pub fn to_tokenized_vec(&self) -> Result> { let mut result = vec![]; let mut start = 0; - if self.tags.is_empty() { - for (i, b) in self.boundaries.iter().enumerate() { - match b { - BoundaryType::WordBoundary => { - let end = unsafe { *self.char_to_str_pos.get_unchecked(i + 1) }; - let surface = unsafe { self.text.get_unchecked(start..end) }; - result.push(Token { surface, tag: None }); - start = end; - } - BoundaryType::NotWordBoundary => (), - BoundaryType::Unknown => { - return Err(VaporettoError::invalid_sentence( - "contains an unknown boundary", - )); - } + for (i, b) in self.boundaries.iter().enumerate() { + match b { + BoundaryType::WordBoundary => { + let end = unsafe { *self.char_to_str_pos.get_unchecked(i + 1) }; + let surface = unsafe { self.text.get_unchecked(start..end) }; + let tag = self + .tags + .get(i) + .and_then(|x| x.as_ref()) + .map(|x| x.as_str()); + result.push(Token { surface, tag }); + start = end; } - } - let surface = unsafe { self.text.get_unchecked(start..) }; - result.push(Token { surface, tag: None }); - } else { - for (i, (b, tag)) in self.boundaries.iter().zip(&self.tags).enumerate() { - match b { - BoundaryType::WordBoundary => { - let end = unsafe { *self.char_to_str_pos.get_unchecked(i + 1) }; - let surface = unsafe { self.text.get_unchecked(start..end) }; - let tag = tag.as_ref().map(|x| x.as_str()); - result.push(Token { surface, tag }); - start = end; - } - BoundaryType::NotWordBoundary => (), - BoundaryType::Unknown => { - return Err(VaporettoError::invalid_sentence( - "contains an unknown boundary", - )); - } + BoundaryType::NotWordBoundary => (), + BoundaryType::Unknown => { + return Err(VaporettoError::invalid_sentence( + "contains an unknown boundary", + )); } } - let surface = unsafe { self.text.get_unchecked(start..) }; - let tag = self - .tags - .last() - .and_then(|x| x.as_ref()) - .map(|x| x.as_str()); - result.push(Token { surface, tag }); } + let surface = unsafe { self.text.get_unchecked(start..) }; + let tag = self + .tags + .last() + .and_then(|x| x.as_ref()) + .map(|x| x.as_str()); + result.push(Token { surface, tag }); Ok(result) } @@ -1719,6 +1706,18 @@ mod tests { assert_eq!(expected, s.unwrap()); } + #[test] + fn test_sentence_from_tokenized_with_tags_two_slashes() { + let s = Sentence::from_tokenized( + "Rust/名詞 で 良い/形容詞/動詞 プログラミング 体験 を !/補助記号", + ); + + assert_eq!( + "InvalidArgumentError: tokenized_text: invalid slash found", + &s.err().unwrap().to_string() + ); + } + #[test] fn test_sentence_from_tokenized_with_tags() { let s = @@ -1867,6 +1866,31 @@ mod tests { assert_eq!(expected, s); } + #[test] + fn test_sentence_update_tokenized_two_slashes() { + let mut s = Sentence::from_raw("12345").unwrap(); + let result = + s.update_tokenized("Rust/名詞 で 良い/形容詞/動詞 プログラミング 体験 を !/補助記号"); + + assert_eq!( + "InvalidArgumentError: tokenized_text: invalid slash found", + &result.err().unwrap().to_string() + ); + + let expected = Sentence { + text: " ".to_string(), + chars: vec![' '], + str_to_char_pos: vec![0, 1], + char_to_str_pos: vec![0, 1], + char_type: vec![Other as u8], + boundaries: vec![], + boundary_scores: vec![], + tag_scores: TagScores::default(), + tags: vec![None], + }; + assert_eq!(expected, s); + } + #[test] fn test_sentence_update_tokenized_with_tags() { let mut s = Sentence::from_raw("12345").unwrap();