From c0609991ebcf80b4ef72e44eb845283d9b4fb050 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Tue, 30 Aug 2022 13:42:33 -0400 Subject: [PATCH] Add support for != in conditions --- limitador/src/limit.rs | 67 ++++++++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 15 deletions(-) diff --git a/limitador/src/limit.rs b/limitador/src/limit.rs index f3d470ed..45825ff0 100644 --- a/limitador/src/limit.rs +++ b/limitador/src/limit.rs @@ -113,15 +113,24 @@ impl TryFrom for Condition { &tokens[1].token_type, &tokens[2].token_type, ) { - (TokenType::Identifier, TokenType::EqualEqual, TokenType::String) => { + ( + TokenType::Identifier, + TokenType::EqualEqual | TokenType::NotEqual, + TokenType::String, + ) => { if let ( Some(Literal::Identifier(var_name)), Some(Literal::String(operand)), ) = (&tokens[0].literal, &tokens[2].literal) { + let predicate = match &tokens[1].token_type { + TokenType::EqualEqual => Predicate::Equal, + TokenType::NotEqual => Predicate::NotEqual, + _ => unreachable!(), + }; Ok(Condition { var_name: var_name.clone(), - predicate: Predicate::EQUAL, + predicate, operand: operand.clone(), }) } else { @@ -131,15 +140,24 @@ impl TryFrom for Condition { ) } } - (TokenType::String, TokenType::EqualEqual, TokenType::Identifier) => { + ( + TokenType::String, + TokenType::EqualEqual | TokenType::NotEqual, + TokenType::Identifier, + ) => { if let ( Some(Literal::String(operand)), Some(Literal::Identifier(var_name)), ) = (&tokens[0].literal, &tokens[2].literal) { + let predicate = match &tokens[1].token_type { + TokenType::EqualEqual => Predicate::Equal, + TokenType::NotEqual => Predicate::NotEqual, + _ => unreachable!(), + }; Ok(Condition { var_name: var_name.clone(), - predicate: Predicate::EQUAL, + predicate, operand: operand.clone(), }) } else { @@ -159,7 +177,7 @@ impl TryFrom for Condition { deprecated::deprecated_syntax_used(); Ok(Condition { var_name: var_name.clone(), - predicate: Predicate::EQUAL, + predicate: Predicate::Equal, operand: operand.clone(), }) } else { @@ -179,7 +197,7 @@ impl TryFrom for Condition { deprecated::deprecated_syntax_used(); Ok(Condition { var_name: var_name.clone(), - predicate: Predicate::EQUAL, + predicate: Predicate::Equal, operand: operand.to_string(), }) } else { @@ -193,7 +211,7 @@ impl TryFrom for Condition { let faulty = match (t1, t2) { ( TokenType::Identifier | TokenType::String, - TokenType::EqualEqual, + TokenType::EqualEqual | TokenType::NotEqual, ) => 2, (TokenType::Identifier | TokenType::String, _) => 1, (_, _) => 0, @@ -248,13 +266,15 @@ impl From for String { #[derive(PartialEq, Eq, Debug, Clone, Hash)] pub enum Predicate { - EQUAL, + Equal, + NotEqual, } impl Predicate { fn test(&self, lhs: &str, rhs: &str) -> bool { match self { - Predicate::EQUAL => lhs == rhs, + Predicate::Equal => lhs == rhs, + Predicate::NotEqual => lhs != rhs, } } } @@ -262,7 +282,8 @@ impl Predicate { impl From for String { fn from(op: Predicate) -> Self { match op { - Predicate::EQUAL => "==".to_string(), + Predicate::Equal => "==".to_string(), + Predicate::NotEqual => "!=".to_string(), } } } @@ -456,6 +477,7 @@ mod conditions { pub enum TokenType { // Predicates EqualEqual, + NotEqual, //Literals Identifier, @@ -491,6 +513,7 @@ mod conditions { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self.token_type { TokenType::EqualEqual => write!(f, "Equality (==)"), + TokenType::NotEqual => write!(f, "Unequal (!=)"), TokenType::Identifier => { write!(f, "Identifier: {}", self.literal.as_ref().unwrap()) } @@ -548,6 +571,20 @@ mod conditions { }) } } + '!' => { + if self.next_matches('=') { + Ok(Some(Token { + token_type: TokenType::NotEqual, + literal: None, + pos: self.pos - 1, + })) + } else { + Err(SyntaxError { + pos: self.pos, + error: ErrorType::InvalidCharacter(self.input[self.pos - 1]), + }) + } + } '"' | '\'' => self.scan_string(character).map(Some), ' ' | '\n' | '\r' | '\t' => Ok(None), _ => { @@ -911,7 +948,7 @@ mod tests { result, Condition { var_name: "x".to_string(), - predicate: Predicate::EQUAL, + predicate: Predicate::Equal, operand: "5".to_string(), } ); @@ -922,7 +959,7 @@ mod tests { result, Condition { var_name: "foobar".to_string(), - predicate: Predicate::EQUAL, + predicate: Predicate::Equal, operand: "ok".to_string(), } ); @@ -933,7 +970,7 @@ mod tests { result, Condition { var_name: "foobar".to_string(), - predicate: Predicate::EQUAL, + predicate: Predicate::Equal, operand: "ok".to_string(), } ); @@ -954,7 +991,7 @@ mod tests { .expect("should fail parsing"); assert_eq!( result.to_string(), - "SyntaxError: Invalid character `!` at offset 3 of condition \"x != 5 && x > 12\"" + "SyntaxError: Invalid character `&` at offset 8 of condition \"x != 5 && x > 12\"" .to_string() ); } @@ -963,7 +1000,7 @@ mod tests { fn condition_serialization() { let condition = Condition { var_name: "foobar".to_string(), - predicate: Predicate::EQUAL, + predicate: Predicate::Equal, operand: "ok".to_string(), }; let result = serde_json::to_string(&condition).expect("Should serialize");