From 6326ceec3f18fc030f19f5300df18f2a8265df3d Mon Sep 17 00:00:00 2001 From: David Tolnay Date: Thu, 4 May 2023 16:37:04 -0700 Subject: [PATCH] Don't panic in serde_test on running out of tokens --- serde_test/src/de.rs | 109 ++++++++++++++++++++----------------------- 1 file changed, 50 insertions(+), 59 deletions(-) diff --git a/serde_test/src/de.rs b/serde_test/src/de.rs index 39df2e8ca..7dc3adb09 100644 --- a/serde_test/src/de.rs +++ b/serde_test/src/de.rs @@ -33,10 +33,8 @@ fn unexpected(token: Token) -> Error { )) } -macro_rules! end_of_tokens { - () => { - panic!("ran out of tokens to deserialize") - }; +fn end_of_tokens() -> Error { + de::Error::custom("ran out of tokens to deserialize") } impl<'de> Deserializer<'de> { @@ -48,11 +46,8 @@ impl<'de> Deserializer<'de> { self.tokens.first().cloned() } - fn peek_token(&self) -> Token { - match self.peek_token_opt() { - Some(token) => token, - None => end_of_tokens!(), - } + fn peek_token(&self) -> Result { + self.peek_token_opt().ok_or_else(end_of_tokens) } pub fn next_token_opt(&mut self) -> Option { @@ -65,14 +60,10 @@ impl<'de> Deserializer<'de> { } } - fn next_token(&mut self) -> Token { - match self.tokens.split_first() { - Some((&first, rest)) => { - self.tokens = rest; - first - } - None => end_of_tokens!(), - } + fn next_token(&mut self) -> Result { + let (&first, rest) = self.tokens.split_first().ok_or_else(end_of_tokens)?; + self.tokens = rest; + Ok(first) } pub fn remaining(&self) -> usize { @@ -128,7 +119,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - let token = self.next_token(); + let token = self.next_token()?; match token { Token::Bool(v) => visitor.visit_bool(v), Token::I8(v) => visitor.visit_i8(v), @@ -160,47 +151,47 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { Token::Map { len } => self.visit_map(len, Token::MapEnd, visitor), Token::Struct { len, .. } => self.visit_map(Some(len), Token::StructEnd, visitor), Token::Enum { .. } => { - let variant = self.next_token(); - let next = self.peek_token(); + let variant = self.next_token()?; + let next = self.peek_token()?; match (variant, next) { (Token::Str(variant), Token::Unit) => { - self.next_token(); + self.next_token()?; visitor.visit_str(variant) } (Token::BorrowedStr(variant), Token::Unit) => { - self.next_token(); + self.next_token()?; visitor.visit_borrowed_str(variant) } (Token::String(variant), Token::Unit) => { - self.next_token(); + self.next_token()?; visitor.visit_string(variant.to_string()) } (Token::Bytes(variant), Token::Unit) => { - self.next_token(); + self.next_token()?; visitor.visit_bytes(variant) } (Token::BorrowedBytes(variant), Token::Unit) => { - self.next_token(); + self.next_token()?; visitor.visit_borrowed_bytes(variant) } (Token::ByteBuf(variant), Token::Unit) => { - self.next_token(); + self.next_token()?; visitor.visit_byte_buf(variant.to_vec()) } (Token::U8(variant), Token::Unit) => { - self.next_token(); + self.next_token()?; visitor.visit_u8(variant) } (Token::U16(variant), Token::Unit) => { - self.next_token(); + self.next_token()?; visitor.visit_u16(variant) } (Token::U32(variant), Token::Unit) => { - self.next_token(); + self.next_token()?; visitor.visit_u32(variant) } (Token::U64(variant), Token::Unit) => { - self.next_token(); + self.next_token()?; visitor.visit_u64(variant) } (variant, Token::Unit) => Err(unexpected(variant)), @@ -239,13 +230,13 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - match self.peek_token() { + match self.peek_token()? { Token::Unit | Token::None => { - self.next_token(); + self.next_token()?; visitor.visit_none() } Token::Some => { - self.next_token(); + self.next_token()?; visitor.visit_some(self) } _ => self.deserialize_any(visitor), @@ -261,9 +252,9 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - match self.peek_token() { + match self.peek_token()? { Token::Enum { name: n } if name == n => { - self.next_token(); + self.next_token()?; visitor.visit_enum(DeserializerEnumVisitor { de: self }) } @@ -283,7 +274,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - match self.peek_token() { + match self.peek_token()? { Token::UnitStruct { .. } => { assert_next_token(self, Token::UnitStruct { name: name })?; visitor.visit_unit() @@ -300,7 +291,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - match self.peek_token() { + match self.peek_token()? { Token::NewtypeStruct { .. } => { assert_next_token(self, Token::NewtypeStruct { name: name })?; visitor.visit_newtype_struct(self) @@ -313,21 +304,21 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - match self.peek_token() { + match self.peek_token()? { Token::Unit | Token::UnitStruct { .. } => { - self.next_token(); + self.next_token()?; visitor.visit_unit() } Token::Seq { .. } => { - self.next_token(); + self.next_token()?; self.visit_seq(Some(len), Token::SeqEnd, visitor) } Token::Tuple { .. } => { - self.next_token(); + self.next_token()?; self.visit_seq(Some(len), Token::TupleEnd, visitor) } Token::TupleStruct { .. } => { - self.next_token(); + self.next_token()?; self.visit_seq(Some(len), Token::TupleStructEnd, visitor) } _ => self.deserialize_any(visitor), @@ -343,9 +334,9 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - match self.peek_token() { + match self.peek_token()? { Token::Unit => { - self.next_token(); + self.next_token()?; visitor.visit_unit() } Token::UnitStruct { .. } => { @@ -353,11 +344,11 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { visitor.visit_unit() } Token::Seq { .. } => { - self.next_token(); + self.next_token()?; self.visit_seq(Some(len), Token::SeqEnd, visitor) } Token::Tuple { .. } => { - self.next_token(); + self.next_token()?; self.visit_seq(Some(len), Token::TupleEnd, visitor) } Token::TupleStruct { len: n, .. } => { @@ -377,13 +368,13 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> { where V: Visitor<'de>, { - match self.peek_token() { + match self.peek_token()? { Token::Struct { len: n, .. } => { assert_next_token(self, Token::Struct { name: name, len: n })?; self.visit_map(Some(fields.len()), Token::StructEnd, visitor) } Token::Map { .. } => { - self.next_token(); + self.next_token()?; self.visit_map(Some(fields.len()), Token::MapEnd, visitor) } _ => self.deserialize_any(visitor), @@ -473,7 +464,7 @@ impl<'de, 'a> EnumAccess<'de> for DeserializerEnumVisitor<'a, 'de> { where V: DeserializeSeed<'de>, { - match self.de.peek_token() { + match self.de.peek_token()? { Token::UnitVariant { variant: v, .. } | Token::NewtypeVariant { variant: v, .. } | Token::TupleVariant { variant: v, .. } @@ -494,9 +485,9 @@ impl<'de, 'a> VariantAccess<'de> for DeserializerEnumVisitor<'a, 'de> { type Error = Error; fn unit_variant(self) -> Result<(), Error> { - match self.de.peek_token() { + match self.de.peek_token()? { Token::UnitVariant { .. } => { - self.de.next_token(); + self.de.next_token()?; Ok(()) } _ => Deserialize::deserialize(self.de), @@ -507,9 +498,9 @@ impl<'de, 'a> VariantAccess<'de> for DeserializerEnumVisitor<'a, 'de> { where T: DeserializeSeed<'de>, { - match self.de.peek_token() { + match self.de.peek_token()? { Token::NewtypeVariant { .. } => { - self.de.next_token(); + self.de.next_token()?; seed.deserialize(self.de) } _ => seed.deserialize(self.de), @@ -520,9 +511,9 @@ impl<'de, 'a> VariantAccess<'de> for DeserializerEnumVisitor<'a, 'de> { where V: Visitor<'de>, { - match self.de.peek_token() { + match self.de.peek_token()? { Token::TupleVariant { len: enum_len, .. } => { - let token = self.de.next_token(); + let token = self.de.next_token()?; if len == enum_len { self.de @@ -534,7 +525,7 @@ impl<'de, 'a> VariantAccess<'de> for DeserializerEnumVisitor<'a, 'de> { Token::Seq { len: Some(enum_len), } => { - let token = self.de.next_token(); + let token = self.de.next_token()?; if len == enum_len { self.de.visit_seq(Some(len), Token::SeqEnd, visitor) @@ -554,9 +545,9 @@ impl<'de, 'a> VariantAccess<'de> for DeserializerEnumVisitor<'a, 'de> { where V: Visitor<'de>, { - match self.de.peek_token() { + match self.de.peek_token()? { Token::StructVariant { len: enum_len, .. } => { - let token = self.de.next_token(); + let token = self.de.next_token()?; if fields.len() == enum_len { self.de @@ -568,7 +559,7 @@ impl<'de, 'a> VariantAccess<'de> for DeserializerEnumVisitor<'a, 'de> { Token::Map { len: Some(enum_len), } => { - let token = self.de.next_token(); + let token = self.de.next_token()?; if fields.len() == enum_len { self.de