diff --git a/libs/prisma-value/src/lib.rs b/libs/prisma-value/src/lib.rs index eca96410b3c4..f7c3b4c2333d 100644 --- a/libs/prisma-value/src/lib.rs +++ b/libs/prisma-value/src/lib.rs @@ -13,7 +13,7 @@ pub use error::ConversionFailure; pub type PrismaValueResult = std::result::Result; pub type PrismaListValue = Vec; -#[derive(Debug, PartialEq, Clone, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] +#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize, PartialOrd, Ord)] #[serde(untagged)] pub enum PrismaValue { String(String), @@ -46,12 +46,18 @@ pub enum PrismaValue { /// Stringify a date to the following format /// 1999-05-01T00:00:00.000Z -pub fn stringify_date(date: &DateTime) -> String { +pub fn stringify_datetime(datetime: &DateTime) -> String { // Warning: Be careful if you plan on changing the code below // The findUnique batch optimization expects date inputs to have exactly the same format as date outputs // This works today because clients always send date inputs in the same format as the serialized format below // Updating this without transforming date inputs to the same format WILL break the findUnique batch optimization - date.to_rfc3339_opts(SecondsFormat::Millis, true) + datetime.to_rfc3339_opts(SecondsFormat::Millis, true) +} + +/// Parses an RFC 3339 and ISO 8601 date and time string such as 1996-12-19T16:39:57-08:00, +/// then returns a new DateTime with a parsed FixedOffset. +pub fn parse_datetime(datetime: &str) -> chrono::ParseResult> { + DateTime::parse_from_rfc3339(datetime) } pub fn encode_bytes(bytes: &[u8]) -> String { @@ -135,7 +141,7 @@ fn serialize_date(date: &DateTime, serializer: S) -> Result(bytes: &[u8], serializer: S) -> Result @@ -258,12 +264,19 @@ impl PrismaValue { } } + pub fn into_object(self) -> Option> { + match self { + PrismaValue::Object(obj) => Some(obj), + _ => None, + } + } + pub fn new_float(float: f64) -> PrismaValue { PrismaValue::Float(BigDecimal::from_f64(float).unwrap()) } pub fn new_datetime(datetime: &str) -> PrismaValue { - PrismaValue::DateTime(DateTime::parse_from_rfc3339(datetime).unwrap()) + PrismaValue::DateTime(parse_datetime(datetime).unwrap()) } pub fn as_boolean(&self) -> Option<&bool> { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/composite/equals.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/composite/equals.rs index 6fe204b8970c..50fb82208e14 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/composite/equals.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/composite/equals.rs @@ -221,7 +221,7 @@ mod to_many { } }"#, 2009, - "Query parsing/validation error at `Query.findManyTestModel.where.TestModelWhereInput.to_many_as.CompositeACompositeListFilter.equals`: Value types mismatch. Have: Object({\"a_1\": String(\"Test\"), \"a_2\": Int(0)}), want: Object(CompositeAObjectEqualityInput)" + "`Query.findManyTestModel.where.TestModelWhereInput.to_many_as.CompositeACompositeListFilter.equals`: Value types mismatch. Have: Object([(\"a_1\", String(\"Test\")), (\"a_2\", Int(0))]), want: Object(CompositeAObjectEqualityInput)" ); Ok(()) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json.rs index 806b2e57e5e6..3ea8cede3d3e 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/json.rs @@ -216,14 +216,14 @@ mod json { &runner, r#"query { findManyTestModel(where: { json: { not: { equals: "{}" }}}) { id }}"#, 2009, - "`Query.findManyTestModel.where.TestModelWhereInput.json.JsonNullableFilter.not`: Value types mismatch. Have: Object({\"equals\": String(\"{}\")}), want: Json" + "`Query.findManyTestModel.where.TestModelWhereInput.json.JsonNullableFilter.not`: Value types mismatch. Have: Object([(\"equals\", String(\"{}\"))]), want: Json" ); assert_error!( &runner, r#"query { findManyTestModel(where: { json: { not: { equals: null }}}) { id }}"#, 2009, - "`Query.findManyTestModel.where.TestModelWhereInput.json.JsonNullableFilter.not`: Value types mismatch. Have: Object({\"equals\": Null}), want: Json" + "`Query.findManyTestModel.where.TestModelWhereInput.json.JsonNullableFilter.not`: Value types mismatch. Have: Object([(\"equals\", Null)]), want: Json" ); Ok(()) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/composites/list.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/composites/list.rs index dfbab4fa33b5..5a1590926f1d 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/composites/list.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/composites/list.rs @@ -414,7 +414,7 @@ mod update { runner, query, 2009, - "`Mutation.updateOneTestModel.data.TestModelUpdateInput.to_many_as.CompositeAListUpdateEnvelopeInput.set.CompositeACreateInput.a_2`: Value types mismatch. Have: Object({\"update\": Object({\"increment\": Int(3)})}), want: Int" + "`Mutation.updateOneTestModel.data.TestModelUpdateInput.to_many_as.CompositeAListUpdateEnvelopeInput.set.CompositeACreateInput.a_2`: Value types mismatch. Have: Object([(\"update\", Object([(\"increment\", Int(3))]))]), want: Int" ); // Ensure `update` cannot be used in the Unchecked type @@ -422,7 +422,7 @@ mod update { runner, query, 2009, - "`Mutation.updateOneTestModel.data.TestModelUpdateInput.to_many_as.CompositeAListUpdateEnvelopeInput.set.CompositeACreateInput.a_2`: Value types mismatch. Have: Object({\"update\": Object({\"increment\": Int(3)})}), want: Int" + "`Mutation.updateOneTestModel.data.TestModelUpdateInput.to_many_as.CompositeAListUpdateEnvelopeInput.set.CompositeACreateInput.a_2`: Value types mismatch. Have: Object([(\"update\", Object([(\"increment\", Int(3))]))]), want: Int" ); Ok(()) diff --git a/query-engine/core/src/query_document/error.rs b/query-engine/core/src/query_document/error.rs index 243d1e6d7a47..5b6fb41ef80a 100644 --- a/query-engine/core/src/query_document/error.rs +++ b/query-engine/core/src/query_document/error.rs @@ -1,4 +1,4 @@ -use crate::{query_document::QueryValue, schema::InputType}; +use crate::{query_document::PrismaValue, schema::InputType}; use fmt::Display; use itertools::Itertools; use std::fmt; @@ -63,7 +63,7 @@ pub enum QueryParserErrorKind { ArgumentNotFoundError, FieldCountError(FieldCountError), ValueParseError(String), - ValueTypeMismatchError { have: QueryValue, want: InputType }, + ValueTypeMismatchError { have: PrismaValue, want: InputType }, InputUnionParseError { parsing_errors: Vec }, ValueFitError(String), } diff --git a/query-engine/core/src/query_document/mod.rs b/query-engine/core/src/query_document/mod.rs index b91ee1cbcaf1..36719469f02e 100644 --- a/query-engine/core/src/query_document/mod.rs +++ b/query-engine/core/src/query_document/mod.rs @@ -18,7 +18,6 @@ mod error; mod operation; mod parse_ast; mod parser; -mod query_value; mod selection; mod transformers; @@ -26,13 +25,11 @@ pub use error::*; pub use operation::*; pub use parse_ast::*; pub use parser::*; -pub use query_value::*; pub use selection::*; pub use transformers::*; use crate::resolve_compound_field; -use indexmap::IndexMap; -use prisma_models::ModelRef; +use prisma_models::{ModelRef, PrismaValue}; use schema::QuerySchemaRef; use schema_builder::constants::*; use std::collections::HashMap; @@ -82,12 +79,12 @@ impl BatchDocument { where_obj.iter().any(|(key, val)| match val { // If it's a compound, then it's still considered as scalar - QueryValue::Object(_) if resolve_compound_field(key, model).is_some() => false, + PrismaValue::Object(_) if resolve_compound_field(key, model).is_some() => false, // Otherwise, we just look for a scalar field inside the model. If it's not one, then we break. val => match model.fields().find_from_scalar(&key) { Ok(_) => match val { // Consider scalar _only_ if the filter object contains "equals". eg: `{ scalar_field: { equals: 1 } }` - QueryValue::Object(obj) => !obj.contains_key(filters::EQUALS), + PrismaValue::Object(obj) => !obj.iter().any(|(k, _)| k.as_str() == filters::EQUALS), _ => false, }, Err(_) => true, @@ -158,7 +155,7 @@ impl BatchDocumentTransaction { #[derive(Debug, Clone)] pub struct CompactedDocument { - pub arguments: Vec>, + pub arguments: Vec>, pub nested_selection: Vec, pub operation: Operation, pub keys: Vec, @@ -246,11 +243,11 @@ impl CompactedDocument { // Convert the selections into a map of arguments. This defines the // response order and how we fetch the right data from the response set. - let arguments: Vec> = selections + let arguments: Vec> = selections .into_iter() .map(|mut sel| { let where_obj = sel.pop_argument().unwrap().1.into_object().unwrap(); - let filter_map: HashMap = extract_filter(where_obj, model).into_iter().collect(); + let filter_map: HashMap = extract_filter(where_obj, model).into_iter().collect(); filter_map }) @@ -260,7 +257,7 @@ impl CompactedDocument { let keys: Vec<_> = arguments[0] .iter() .flat_map(|pair| match pair { - (_, QueryValue::Object(obj)) => obj.keys().map(ToOwned::to_owned).collect(), + (_, PrismaValue::Object(obj)) => obj.iter().map(|(key, _)| key.to_owned()).collect(), (key, _) => vec![key.to_owned()], }) .collect(); @@ -284,16 +281,19 @@ impl CompactedDocument { /// Furthermore, this list is used to match the results of the findMany query back to the original findUnique queries. /// Consequently, we only extract EQUALS filters or else we would have to manually implement other filters. /// This is a limitation that _could_ technically be lifted but that's not worth it for now. -fn extract_filter(where_obj: IndexMap, model: &ModelRef) -> Vec<(String, QueryValue)> { +fn extract_filter(where_obj: Vec, model: &ModelRef) -> Vec { where_obj .into_iter() .flat_map(|(key, val)| match val { // This means our query has a compound field in the form of: {co1_col2: { col1_col2: { col1: , col2: } }} - QueryValue::Object(obj) if resolve_compound_field(&key, model).is_some() => obj.into_iter().collect(), + PrismaValue::Object(obj) if resolve_compound_field(&key, model).is_some() => obj.into_iter().collect(), // This means our query has a scalar filter in the form of {col1: { equals: }} - QueryValue::Object(obj) => { + PrismaValue::Object(obj) => { // This is safe because it's been validated before in the `.can_compact` method. - let equal_val = obj.get(filters::EQUALS).expect("we only support scalar equals filters"); + let (_, equal_val) = obj + .iter() + .find(|(k, _)| k == filters::EQUALS) + .expect("we only support scalar equals filters"); vec![(key, equal_val.clone())] } diff --git a/query-engine/core/src/query_document/parser.rs b/query-engine/core/src/query_document/parser.rs index 942462c363b4..03ce4d4f6fb3 100644 --- a/query-engine/core/src/query_document/parser.rs +++ b/query-engine/core/src/query_document/parser.rs @@ -2,14 +2,11 @@ use super::*; use crate::schema::*; use bigdecimal::{BigDecimal, ToPrimitive}; use chrono::prelude::*; -use indexmap::IndexMap; use prisma_value::PrismaValue; use psl::dml::ValueGeneratorFn; -use std::{borrow::Borrow, collections::HashSet, convert::TryFrom, str::FromStr, sync::Arc}; +use std::{borrow::Borrow, collections::HashSet, convert::TryFrom, hash::Hash, str::FromStr, sync::Arc}; use uuid::Uuid; -// todo: validate is one of! - pub struct QueryDocumentParser { /// NOW() default value that's reused for all NOW() defaults on a single query default_now: PrismaValue, @@ -96,7 +93,7 @@ impl QueryDocumentParser { &self, parent_path: QueryPath, schema_field: &OutputFieldRef, - given_arguments: &[(String, QueryValue)], + given_arguments: &[(String, PrismaValue)], ) -> QueryParserResult> { let left: HashSet<&str> = schema_field.arguments.iter().map(|arg| arg.name.as_str()).collect(); let right: HashSet<&str> = given_arguments.iter().map(|arg| arg.0.as_str()).collect(); @@ -119,7 +116,7 @@ impl QueryDocumentParser { .iter() .filter_map(|schema_input_arg| { // Match schema argument field to an argument field in the incoming document. - let selection_arg: Option<(String, QueryValue)> = given_arguments + let selection_arg: Option<(String, PrismaValue)> = given_arguments .iter() .find(|given_argument| given_argument.0 == schema_input_arg.name) .cloned(); @@ -149,12 +146,12 @@ impl QueryDocumentParser { .collect() } - /// Parses and validates a QueryValue against possible input types. + /// Parses and validates a PrismaValue against possible input types. /// Matching is done in order of definition on the input type. First matching type wins. fn parse_input_value( &self, parent_path: QueryPath, - value: QueryValue, + value: PrismaValue, possible_input_types: &[InputType], ) -> QueryParserResult { let mut parse_results = vec![]; @@ -163,10 +160,10 @@ impl QueryDocumentParser { let value = value.clone(); let result = match (&value, input_type) { // Null handling - (QueryValue::Null, InputType::Scalar(ScalarType::Null)) => { + (PrismaValue::Null, InputType::Scalar(ScalarType::Null)) => { Ok(ParsedInputValue::Single(PrismaValue::Null)) } - (QueryValue::Null, _) => Err(QueryParserError { + (PrismaValue::Null, _) => Err(QueryParserError { path: parent_path.clone(), error_kind: QueryParserErrorKind::RequiredValueNotSetError, }), @@ -177,17 +174,17 @@ impl QueryDocumentParser { .map(ParsedInputValue::Single), // Enum handling - (QueryValue::Enum(_), InputType::Enum(et)) => self.parse_enum(&parent_path, value, &et.into_arc()), - (QueryValue::String(_), InputType::Enum(et)) => self.parse_enum(&parent_path, value, &et.into_arc()), - (QueryValue::Boolean(_), InputType::Enum(et)) => self.parse_enum(&parent_path, value, &et.into_arc()), + (PrismaValue::Enum(_), InputType::Enum(et)) => self.parse_enum(&parent_path, value, &et.into_arc()), + (PrismaValue::String(_), InputType::Enum(et)) => self.parse_enum(&parent_path, value, &et.into_arc()), + (PrismaValue::Boolean(_), InputType::Enum(et)) => self.parse_enum(&parent_path, value, &et.into_arc()), // List handling. - (QueryValue::List(values), InputType::List(l)) => self + (PrismaValue::List(values), InputType::List(l)) => self .parse_list(&parent_path, values.clone(), &l) .map(ParsedInputValue::List), // Object handling - (QueryValue::Object(o), InputType::Object(obj)) => self + (PrismaValue::Object(o), InputType::Object(obj)) => self .parse_input_object(parent_path.clone(), o.clone(), obj.into_arc()) .map(ParsedInputValue::Map), @@ -231,55 +228,70 @@ impl QueryDocumentParser { fn parse_scalar( &self, parent_path: &QueryPath, - value: QueryValue, + value: PrismaValue, scalar_type: &ScalarType, ) -> QueryParserResult { match (value, scalar_type.clone()) { - (QueryValue::String(s), ScalarType::String) => Ok(PrismaValue::String(s)), - (QueryValue::String(s), ScalarType::Xml) => Ok(PrismaValue::Xml(s)), - (QueryValue::String(s), ScalarType::JsonList) => self.parse_json_list(parent_path, &s), - (QueryValue::String(s), ScalarType::Bytes) => self.parse_bytes(parent_path, s), - (QueryValue::String(s), ScalarType::Decimal) => self.parse_decimal(parent_path, s), - (QueryValue::String(s), ScalarType::BigInt) => self.parse_bigint(parent_path, s), - (QueryValue::String(s), ScalarType::UUID) => { + // Identity matchers + (PrismaValue::String(s), ScalarType::String) => Ok(PrismaValue::String(s)), + (PrismaValue::Boolean(b), ScalarType::Boolean) => Ok(PrismaValue::Boolean(b)), + (PrismaValue::Json(json), ScalarType::Json) => Ok(PrismaValue::Json(json)), + (PrismaValue::Xml(xml), ScalarType::Xml) => Ok(PrismaValue::Xml(xml)), + (PrismaValue::Uuid(uuid), ScalarType::UUID) => Ok(PrismaValue::Uuid(uuid)), + (PrismaValue::Bytes(bytes), ScalarType::Bytes) => Ok(PrismaValue::Bytes(bytes)), + (PrismaValue::BigInt(b_int), ScalarType::BigInt) => Ok(PrismaValue::BigInt(b_int)), + (PrismaValue::DateTime(s), ScalarType::DateTime) => Ok(PrismaValue::DateTime(s)), + (PrismaValue::Null, ScalarType::Null) => Ok(PrismaValue::Null), + + // String coercion matchers + (PrismaValue::String(s), ScalarType::Xml) => Ok(PrismaValue::Xml(s)), + (PrismaValue::String(s), ScalarType::JsonList) => self.parse_json_list(parent_path, &s), + (PrismaValue::String(s), ScalarType::Bytes) => self.parse_bytes(parent_path, s), + (PrismaValue::String(s), ScalarType::Decimal) => self.parse_decimal(parent_path, s), + (PrismaValue::String(s), ScalarType::BigInt) => self.parse_bigint(parent_path, s), + (PrismaValue::String(s), ScalarType::UUID) => { self.parse_uuid(parent_path, s.as_str()).map(PrismaValue::Uuid) } - (QueryValue::String(s), ScalarType::Json) => { + (PrismaValue::String(s), ScalarType::Json) => { Ok(PrismaValue::Json(self.parse_json(parent_path, &s).map(|_| s)?)) } - (QueryValue::String(s), ScalarType::DateTime) => { + (PrismaValue::String(s), ScalarType::DateTime) => { self.parse_datetime(parent_path, s.as_str()).map(PrismaValue::DateTime) } - (QueryValue::DateTime(s), ScalarType::DateTime) => Ok(PrismaValue::DateTime(s)), - (QueryValue::Int(i), ScalarType::Int) => Ok(PrismaValue::Int(i)), - (QueryValue::Int(i), ScalarType::Float) => Ok(PrismaValue::Float(BigDecimal::from(i))), - (QueryValue::Int(i), ScalarType::Decimal) => Ok(PrismaValue::Float(BigDecimal::from(i))), - (QueryValue::Int(i), ScalarType::BigInt) => Ok(PrismaValue::BigInt(i)), + // Int coercion matchers + (PrismaValue::Int(i), ScalarType::Int) => Ok(PrismaValue::Int(i)), + (PrismaValue::Int(i), ScalarType::Float) => Ok(PrismaValue::Float(BigDecimal::from(i))), + (PrismaValue::Int(i), ScalarType::Decimal) => Ok(PrismaValue::Float(BigDecimal::from(i))), + (PrismaValue::Int(i), ScalarType::BigInt) => Ok(PrismaValue::BigInt(i)), - (QueryValue::Float(f), ScalarType::Float) => Ok(PrismaValue::Float(f)), - (QueryValue::Float(f), ScalarType::Decimal) => Ok(PrismaValue::Float(f)), - (QueryValue::Float(f), ScalarType::Int) => match f.to_i64() { + // Float coercion matchers + (PrismaValue::Float(f), ScalarType::Float) => Ok(PrismaValue::Float(f)), + (PrismaValue::Float(f), ScalarType::Decimal) => Ok(PrismaValue::Float(f)), + (PrismaValue::Float(f), ScalarType::Int) => match f.to_i64() { Some(converted) => Ok(PrismaValue::Int(converted)), None => Err(QueryParserError::new(parent_path.clone(), QueryParserErrorKind::ValueFitError( format!("Unable to fit float value (or large JS integer serialized in exponent notation) '{}' into a 64 Bit signed integer for field '{}'. If you're trying to store large integers, consider using `BigInt`.", f, parent_path.last().unwrap())))), }, - (QueryValue::Boolean(b), ScalarType::Boolean) => Ok(PrismaValue::Boolean(b)), + // UUID coercion matchers + (PrismaValue::Uuid(uuid), ScalarType::String) => Ok(PrismaValue::String(uuid.to_string())), // All other combinations are value type mismatches. - (qv, _) => Err(QueryParserError { - path: parent_path.clone(), - error_kind: QueryParserErrorKind::ValueTypeMismatchError { - have: qv, - want: InputType::Scalar(scalar_type.clone()), - }, - }), + (qv, _) => { + Err(QueryParserError { + path: parent_path.clone(), + error_kind: QueryParserErrorKind::ValueTypeMismatchError { + have: qv, + want: InputType::Scalar(scalar_type.clone()), + }, + }) + }, } } fn parse_datetime(&self, path: &QueryPath, s: &str) -> QueryParserResult> { - DateTime::parse_from_rfc3339(s).map_err(|err| QueryParserError { + prisma_value::parse_datetime(s).map_err(|err| QueryParserError { path: path.clone(), error_kind: QueryParserErrorKind::ValueParseError(format!( "Invalid DateTime: '{}' (must be ISO 8601 compatible). Underlying error: {}", @@ -356,7 +368,7 @@ impl QueryDocumentParser { fn parse_list( &self, path: &QueryPath, - values: Vec, + values: Vec, value_type: &InputType, ) -> QueryParserResult> { values @@ -365,11 +377,11 @@ impl QueryDocumentParser { .collect::>>() } - fn parse_enum(&self, path: &QueryPath, val: QueryValue, typ: &EnumTypeRef) -> QueryParserResult { + fn parse_enum(&self, path: &QueryPath, val: PrismaValue, typ: &EnumTypeRef) -> QueryParserResult { let raw = match val { - QueryValue::Enum(s) => s, - QueryValue::String(s) => s, - QueryValue::Boolean(b) => if b { "true" } else { "false" }.to_owned(), // Case where a bool was misinterpreted as constant literal + PrismaValue::Enum(s) => s, + PrismaValue::String(s) => s, + PrismaValue::Boolean(b) => if b { "true" } else { "false" }.to_owned(), // Case where a bool was misinterpreted as constant literal _ => { return Err(QueryParserError { path: path.clone(), @@ -412,7 +424,7 @@ impl QueryDocumentParser { fn parse_input_object( &self, parent_path: QueryPath, - object: IndexMap, + object: Vec<(String, PrismaValue)>, schema_object: InputObjectTypeStrongRef, ) -> QueryParserResult { let path = parent_path.add(schema_object.identifier.name().to_owned()); @@ -422,7 +434,7 @@ impl QueryDocumentParser { .map(|field| field.name.as_str()) .collect(); - let right: HashSet<&str> = object.keys().map(|k| k.as_str()).collect(); + let right: HashSet<&str> = object.iter().map(|(k, _)| k.as_str()).collect(); let diff = Diff::new(&left, &right); // First, check that all fields **not** provided in the query (left diff) are optional, @@ -439,16 +451,16 @@ impl QueryDocumentParser { // If it's not optional and has no default, a required field has not been provided. match &field.default_value { Some(default_value) => { - let query_value = match &default_value.kind { + let default_pv = match &default_value.kind { psl::dml::DefaultKind::Expression(ref expr) if matches!(expr.generator(), ValueGeneratorFn::Now) => { - self.default_now.clone().into() + self.default_now.clone() } - _ => default_value.get()?.into(), + _ => default_value.get()?, }; - match self.parse_input_value(path, query_value, &field.field_types) { + match self.parse_input_value(path, default_pv, &field.field_types) { Ok(value) => Some(Ok((field.name.clone(), value))), Err(err) => Some(Err(err)), } @@ -531,13 +543,13 @@ impl QueryDocumentParser { } #[derive(Debug)] -struct Diff<'a, T: std::cmp::Eq + std::hash::Hash> { +struct Diff<'a, T: Eq + Hash> { pub left: Vec<&'a T>, pub right: Vec<&'a T>, pub _equal: Vec<&'a T>, } -impl<'a, T: std::cmp::Eq + std::hash::Hash> Diff<'a, T> { +impl<'a, T: Eq + Hash> Diff<'a, T> { fn new(left_side: &'a HashSet, right_side: &'a HashSet) -> Diff<'a, T> { let left: Vec<&T> = left_side.difference(right_side).collect(); let right: Vec<&T> = right_side.difference(left_side).collect(); diff --git a/query-engine/core/src/query_document/query_value.rs b/query-engine/core/src/query_document/query_value.rs deleted file mode 100644 index 80ddf671e317..000000000000 --- a/query-engine/core/src/query_document/query_value.rs +++ /dev/null @@ -1,68 +0,0 @@ -use bigdecimal::BigDecimal; -use indexmap::IndexMap; -use prisma_value::{stringify_date, PrismaValue}; - -#[derive(Debug, Clone, Eq)] -pub enum QueryValue { - Int(i64), - Float(BigDecimal), - String(String), - Boolean(bool), - Null, - Enum(String), - List(Vec), - Object(IndexMap), - DateTime(chrono::DateTime), -} - -impl PartialEq for QueryValue { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (QueryValue::Int(n1), QueryValue::Int(n2)) => n1 == n2, - (QueryValue::Float(n1), QueryValue::Float(n2)) => n1 == n2, - (QueryValue::String(s1), QueryValue::String(s2)) => s1 == s2, - (QueryValue::Boolean(b1), QueryValue::Boolean(b2)) => b1 == b2, - (QueryValue::Null, QueryValue::Null) => true, - (QueryValue::Enum(kind1), QueryValue::Enum(kind2)) => kind1 == kind2, - (QueryValue::List(list1), QueryValue::List(list2)) => list1 == list2, - (QueryValue::Object(t1), QueryValue::Object(t2)) => t1 == t2, - (QueryValue::DateTime(t1), QueryValue::DateTime(t2)) => t1 == t2, - (QueryValue::String(t1), QueryValue::DateTime(t2)) | (QueryValue::DateTime(t2), QueryValue::String(t1)) => { - chrono::DateTime::parse_from_rfc3339(t1) - .map(|t1| &t1 == t2) - .unwrap_or_else(|_| t1 == stringify_date(t2).as_str()) - } - _ => false, - } - } -} - -impl QueryValue { - pub fn into_object(self) -> Option> { - match self { - Self::Object(map) => Some(map), - _ => None, - } - } -} - -impl From for QueryValue { - fn from(pv: PrismaValue) -> Self { - match pv { - PrismaValue::String(s) => Self::String(s), - PrismaValue::Float(f) => Self::Float(f), - PrismaValue::Boolean(b) => Self::Boolean(b), - PrismaValue::DateTime(dt) => Self::DateTime(dt), - PrismaValue::Enum(s) => Self::Enum(s), - PrismaValue::List(l) => Self::List(l.into_iter().map(QueryValue::from).collect()), - PrismaValue::Int(i) => Self::Int(i), - PrismaValue::Null => Self::Null, - PrismaValue::Uuid(u) => Self::String(u.hyphenated().to_string()), - PrismaValue::Json(s) => Self::String(s), - PrismaValue::Xml(s) => Self::String(s), - PrismaValue::Bytes(b) => Self::String(prisma_value::encode_bytes(&b)), - PrismaValue::BigInt(i) => Self::Int(i), - PrismaValue::Object(pairs) => Self::Object(pairs.into_iter().map(|(k, v)| (k, v.into())).collect()), - } - } -} diff --git a/query-engine/core/src/query_document/selection.rs b/query-engine/core/src/query_document/selection.rs index ba102902497a..db27bd0c3c5e 100644 --- a/query-engine/core/src/query_document/selection.rs +++ b/query-engine/core/src/query_document/selection.rs @@ -1,16 +1,16 @@ -use super::QueryValue; +use super::PrismaValue; use indexmap::IndexMap; use itertools::Itertools; use schema_builder::constants::filters; use std::borrow::Cow; -pub type SelectionArgument = (String, QueryValue); +pub type SelectionArgument = (String, PrismaValue); #[derive(Debug, Clone)] pub struct Selection { name: String, alias: Option, - arguments: Vec<(String, QueryValue)>, + arguments: Vec<(String, PrismaValue)>, nested_selections: Vec, } @@ -61,15 +61,15 @@ impl Selection { self.name.starts_with("findUnique") } - pub fn arguments(&self) -> &[(String, QueryValue)] { + pub fn arguments(&self) -> &[(String, PrismaValue)] { &self.arguments } - pub fn pop_argument(&mut self) -> Option<(String, QueryValue)> { + pub fn pop_argument(&mut self) -> Option<(String, PrismaValue)> { self.arguments.pop() } - pub fn push_argument(&mut self, key: impl Into, arg: impl Into) { + pub fn push_argument(&mut self, key: impl Into, arg: impl Into) { self.arguments.push((key.into(), arg.into())); } @@ -104,8 +104,8 @@ impl Selection { #[derive(Debug, Clone, PartialEq)] pub enum SelectionSet<'a> { - Single(Cow<'a, str>, Vec), - Multi(Vec>>, Vec>), + Single(Cow<'a, str>, Vec), + Multi(Vec>>, Vec>), Empty, } @@ -120,7 +120,7 @@ impl<'a> SelectionSet<'a> { Self::default() } - pub fn push(self, column: impl Into>, value: QueryValue) -> Self { + pub fn push(self, column: impl Into>, value: PrismaValue) -> Self { let column = column.into(); match self { @@ -192,7 +192,7 @@ impl<'a> In<'a> { } } -impl<'a> From> for QueryValue { +impl<'a> From> for PrismaValue { fn from(other: In<'a>) -> Self { match other.selection_set { SelectionSet::Multi(key_sets, val_sets) => { @@ -212,22 +212,13 @@ impl<'a> From> for QueryValue { acc.or(ands) }); - QueryValue::from(conjuctive) + PrismaValue::from(conjuctive) } - SelectionSet::Single(key, vals) => { - let mut argument = IndexMap::new(); - argument.insert( - key.to_string(), - QueryValue::Object( - vec![(filters::IN.to_owned(), QueryValue::List(vals))] - .into_iter() - .collect(), - ), - ); - - QueryValue::Object(argument) - } - SelectionSet::Empty => QueryValue::Null, + SelectionSet::Single(key, vals) => PrismaValue::Object(vec![( + key.to_string(), + PrismaValue::Object(vec![(filters::IN.to_owned(), PrismaValue::List(vals))]), + )]), + SelectionSet::Empty => PrismaValue::Null, } } } @@ -236,12 +227,12 @@ impl<'a> From> for QueryValue { pub enum Conjuctive { Or(Vec), And(Vec), - Single(IndexMap), + Single(IndexMap), None, } -impl From> for Conjuctive { - fn from(map: IndexMap) -> Self { +impl From> for Conjuctive { + fn from(map: IndexMap) -> Self { Self::Single(map) } } @@ -280,38 +271,33 @@ impl Conjuctive { } } -impl From for QueryValue { +impl From for PrismaValue { fn from(conjuctive: Conjuctive) -> Self { match conjuctive { Conjuctive::None => Self::Null, - Conjuctive::Single(obj) => QueryValue::Object(single_to_multi_filter(obj)), // QueryValue::Object(obj), + Conjuctive::Single(obj) => PrismaValue::Object(single_to_multi_filter(obj)), // PrismaValue::Object(obj), Conjuctive::Or(conjuctives) => { - let conditions: Vec = conjuctives.into_iter().map(QueryValue::from).collect(); - - let mut map = IndexMap::new(); - map.insert("OR".to_string(), QueryValue::List(conditions)); + let conditions: Vec = conjuctives.into_iter().map(PrismaValue::from).collect(); - QueryValue::Object(map) + PrismaValue::Object(vec![("OR".to_string(), PrismaValue::List(conditions))]) } Conjuctive::And(conjuctives) => { - let conditions: Vec = conjuctives.into_iter().map(QueryValue::from).collect(); + let conditions: Vec = conjuctives.into_iter().map(PrismaValue::from).collect(); - let mut map = IndexMap::new(); - map.insert("AND".to_string(), QueryValue::List(conditions)); - - QueryValue::Object(map) + PrismaValue::Object(vec![("AND".to_string(), PrismaValue::List(conditions))]) } } } } /// Syntax for single-record and multi-record queries -fn single_to_multi_filter(obj: IndexMap) -> IndexMap { - let mut new_obj = IndexMap::new(); +fn single_to_multi_filter(obj: IndexMap) -> Vec<(String, PrismaValue)> { + let mut new_obj = vec![]; for (key, value) in obj { let equality_obj = vec![(filters::EQUALS.to_owned(), value)].into_iter().collect(); - new_obj.insert(key, QueryValue::Object(equality_obj)); + + new_obj.push((key, PrismaValue::Object(equality_obj))); } new_obj diff --git a/query-engine/core/src/response_ir/mod.rs b/query-engine/core/src/response_ir/mod.rs index c2d8dc386380..a8bcf2ad35ec 100644 --- a/query-engine/core/src/response_ir/mod.rs +++ b/query-engine/core/src/response_ir/mod.rs @@ -12,7 +12,6 @@ mod internal; mod ir_serializer; mod response; -use crate::QueryValue; use indexmap::IndexMap; use prisma_models::PrismaValue; use serde::ser::{Serialize, SerializeMap, SerializeSeq, Serializer}; @@ -44,15 +43,14 @@ impl List { self.len() == 0 } - pub fn index_by(self, keys: &[String]) -> Vec<(HashMap, Map)> { - let mut map: Vec<(HashMap, Map)> = Vec::with_capacity(self.len()); + pub fn index_by(self, keys: &[String]) -> Vec<(HashMap, Map)> { + let mut map: Vec<(HashMap, Map)> = Vec::with_capacity(self.len()); for item in self.into_iter() { let inner = item.into_map().unwrap(); - let key: HashMap = keys + let key: HashMap = keys .iter() .map(|key| (key.clone(), inner.get(key).unwrap().clone().into_value().unwrap())) - .map(|(key, val)| (key, QueryValue::from(val))) .collect(); map.push((key, inner)); diff --git a/query-engine/request-handlers/src/graphql/handler.rs b/query-engine/request-handlers/src/graphql/handler.rs index addae388ee74..8514acc6f11e 100644 --- a/query-engine/request-handlers/src/graphql/handler.rs +++ b/query-engine/request-handlers/src/graphql/handler.rs @@ -2,12 +2,18 @@ use super::{GQLBatchResponse, GQLResponse, GraphQlBody}; use crate::PrismaResponse; use futures::FutureExt; use indexmap::IndexMap; +use psl::dml::{ + prisma_value::{parse_datetime, stringify_datetime}, + PrismaValue, +}; use query_core::{ response_ir::{Item, ResponseData}, schema::QuerySchemaRef, BatchDocument, BatchDocumentTransaction, CompactedDocument, Operation, QueryDocument, QueryExecutor, TxId, }; -use std::{fmt, panic::AssertUnwindSafe}; +use std::{collections::HashMap, fmt, panic::AssertUnwindSafe}; + +type ArgsToResult = (HashMap, IndexMap); pub struct GraphQlHandler<'a> { executor: &'a (dyn QueryExecutor + Send + Sync + 'a), @@ -126,7 +132,7 @@ impl<'a> GraphQlHandler<'a> { // At this point, many findUnique queries were converted to a single findMany query and that query was run. // This means we have a list of results and we need to map each result back to their original findUnique query. - // `data` is the piece of logic that allows us to do that mapping. + // `args_to_results` is the data-structure that allows us to do that mapping. // It takes the findMany response and converts it to a map of arguments to result. // Let's take an example. Given the following batched queries: // [ @@ -146,7 +152,7 @@ impl<'a> GraphQlHandler<'a> { // findUnique(where: { id: 1, name: "Bob" }) { id name age } -> { id: 1, name: "Bob", age: 18 } // findUnique(where: { id: 2, name: "Alice" }) { id name age } -> { id: 2, name: "Alice", age: 27 } // ] - let args_to_results = gql_response + let args_to_results: Vec = gql_response .take_data(plural_name) .unwrap() .into_list() @@ -157,13 +163,12 @@ impl<'a> GraphQlHandler<'a> { .into_iter() .map(|args| { let mut responses = GQLResponse::with_capacity(1); - // This is step 5 of the comment above. // Copying here is mandatory due to some of the queries // might be repeated with the same arguments in the original // batch. We need to give the same answer for both of them. - match args_to_results.iter().find(|(a, _)| *a == args) { - Some((_, result)) => { + match Self::find_original_result_from_args(&args_to_results, &args) { + Some(result) => { // Filter out all the keys not selected in the // original query. let result: IndexMap = result @@ -206,4 +211,36 @@ impl<'a> GraphQlHandler<'a> { .execute(tx_id, query_doc, self.query_schema.clone(), trace_id) .await } + + fn find_original_result_from_args<'b>( + args_to_results: &'b [ArgsToResult], + input_args: &'b HashMap, + ) -> Option<&'b IndexMap> { + args_to_results + .iter() + .find(|(arg_from_result, _)| Self::compare_args(arg_from_result, input_args)) + .map(|(_, result)| result) + } + + fn compare_args(left: &HashMap, right: &HashMap) -> bool { + left.iter().all(|(key, left_value)| { + right + .get(key) + .map_or(false, |right_value| Self::compare_values(left_value, right_value)) + }) + } + + /// Compares two PrismaValues but treats DateTime and String as equal when their parsed/stringified versions are equal. + /// We need this when comparing user-inputted values with query response values in the context of compacted queries. + /// User-inputted datetimes are coerced as `PrismaValue::DateTime` but response (and thus serialized) datetimes are `PrismaValue::String`. + /// This should likely _not_ be used outside of this specific context. + fn compare_values(left: &PrismaValue, right: &PrismaValue) -> bool { + match (left, right) { + (PrismaValue::String(t1), PrismaValue::DateTime(t2)) + | (PrismaValue::DateTime(t2), PrismaValue::String(t1)) => parse_datetime(t1) + .map(|t1| &t1 == t2) + .unwrap_or_else(|_| t1 == stringify_datetime(t2).as_str()), + (left, right) => left == right, + } + } } diff --git a/query-engine/request-handlers/src/graphql/protocol_adapter.rs b/query-engine/request-handlers/src/graphql/protocol_adapter.rs index 47064da67225..ca554bf90201 100644 --- a/query-engine/request-handlers/src/graphql/protocol_adapter.rs +++ b/query-engine/request-handlers/src/graphql/protocol_adapter.rs @@ -3,7 +3,7 @@ use bigdecimal::{BigDecimal, FromPrimitive}; use graphql_parser::query::{ Definition, Document, OperationDefinition, Selection as GqlSelection, SelectionSet, Value, }; -use indexmap::IndexMap; +use psl::dml::PrismaValue; use query_core::query_document::*; /// Protocol adapter for GraphQL -> Query Document. @@ -13,7 +13,7 @@ use query_core::query_document::*; /// - Every field of a single `mutation { ... }` is mapped to an `Operation::Write`. /// - If the JSON payload specifies an operation name, only that specific operation is picked and the rest ignored. /// - Fields on the queries are mapped to `Field`s, including arguments. -/// - Concrete values (e.g. in arguments) are mapped to `QueryValue`s. +/// - Concrete values (e.g. in arguments) are mapped to `PrismaValue`s. /// /// Currently unsupported features: /// - Fragments in any form. @@ -95,7 +95,7 @@ impl GraphQLProtocolAdapter { .into_iter() .map(|item| match item { GqlSelection::Field(f) => { - let arguments: Vec<(String, QueryValue)> = f + let arguments: Vec<(String, PrismaValue)> = f .arguments .into_iter() .map(|(k, v)| Ok((k, Self::convert_value(v)?))) @@ -133,42 +133,42 @@ impl GraphQLProtocolAdapter { } } - fn convert_value(value: Value) -> crate::Result { + fn convert_value(value: Value) -> crate::Result { match value { Value::Variable(name) => Err(HandlerError::unsupported_feature( "Variable usage", format!("Variable '{}'.", name), )), Value::Int(i) => match i.as_i64() { - Some(i) => Ok(QueryValue::Int(i)), + Some(i) => Ok(PrismaValue::Int(i)), None => Err(HandlerError::query_conversion(format!( "Invalid 64 bit integer: {:?}", i ))), }, Value::Float(f) => match BigDecimal::from_f64(f) { - Some(dec) => Ok(QueryValue::Float(dec)), + Some(dec) => Ok(PrismaValue::Float(dec)), None => Err(HandlerError::query_conversion(format!("invalid 64-bit float: {:?}", f))), }, - Value::String(s) => Ok(QueryValue::String(s)), - Value::Boolean(b) => Ok(QueryValue::Boolean(b)), - Value::Null => Ok(QueryValue::Null), - Value::Enum(e) => Ok(QueryValue::Enum(e)), + Value::String(s) => Ok(PrismaValue::String(s)), + Value::Boolean(b) => Ok(PrismaValue::Boolean(b)), + Value::Null => Ok(PrismaValue::Null), + Value::Enum(e) => Ok(PrismaValue::Enum(e)), Value::List(values) => { - let values: Vec = values + let values: Vec = values .into_iter() .map(Self::convert_value) - .collect::>>()?; + .collect::>>()?; - Ok(QueryValue::List(values)) + Ok(PrismaValue::List(values)) } Value::Object(map) => { let values = map .into_iter() .map(|(k, v)| Self::convert_value(v).map(|v| (k, v))) - .collect::>>()?; + .collect::>>()?; - Ok(QueryValue::Object(values)) + Ok(PrismaValue::Object(values)) } } } @@ -199,10 +199,10 @@ mod tests { let read = operation.into_read().unwrap(); - let where_args = QueryValue::Object(IndexMap::from([( + let where_args = PrismaValue::Object(vec![( "a_number".to_string(), - QueryValue::Object([("gte".to_string(), QueryValue::Int(1))].into()), - )])); + PrismaValue::Object([("gte".to_string(), PrismaValue::Int(1))].into()), + )]); assert_eq!(read.arguments(), [("where".to_string(), where_args)]); @@ -238,18 +238,18 @@ mod tests { let write = operation.into_write().unwrap(); - let data_args = QueryValue::Object( + let data_args = PrismaValue::Object( [ - ("id".to_string(), QueryValue::Int(1)), + ("id".to_string(), PrismaValue::Int(1)), ( "categories".to_string(), - QueryValue::Object( + PrismaValue::Object( [( "create".to_string(), - QueryValue::List( + PrismaValue::List( [ - QueryValue::Object([("id".to_string(), QueryValue::Int(1))].into()), - QueryValue::Object([("id".to_string(), QueryValue::Int(2))].into()), + PrismaValue::Object([("id".to_string(), PrismaValue::Int(1))].into()), + PrismaValue::Object([("id".to_string(), PrismaValue::Int(2))].into()), ] .into(), ),