From e4de8a68e2d9a30974fff258e3de9a20a00955c3 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Thu, 21 Nov 2024 23:42:43 +0000 Subject: [PATCH] Fix performance regression for JSON tagged union (#1552) --- src/serializers/type_serializers/union.rs | 133 ++++++++++------------ 1 file changed, 58 insertions(+), 75 deletions(-) diff --git a/src/serializers/type_serializers/union.rs b/src/serializers/type_serializers/union.rs index 40a72e9dd..2bf19cff9 100644 --- a/src/serializers/type_serializers/union.rs +++ b/src/serializers/type_serializers/union.rs @@ -117,44 +117,6 @@ fn union_serialize( Ok(None) } -fn tagged_union_serialize( - discriminator_value: Option>, - lookup: &HashMap, - // if this returns `Ok(v)`, we picked a union variant to serialize, where - // `S` is intermediate state which can be passed on to the finalizer - mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult, - extra: &Extra, - choices: &[CombinedSerializer], - retry_with_lax_check: bool, -) -> PyResult> { - let mut new_extra = extra.clone(); - new_extra.check = SerCheck::Strict; - - if let Some(tag) = discriminator_value { - let tag_str = tag.to_string(); - if let Some(&serializer_index) = lookup.get(&tag_str) { - let selected_serializer = &choices[serializer_index]; - - match selector(selected_serializer, &new_extra) { - Ok(v) => return Ok(Some(v)), - Err(_) => { - if retry_with_lax_check { - new_extra.check = SerCheck::Lax; - if let Ok(v) = selector(selected_serializer, &new_extra) { - return Ok(Some(v)); - } - } - } - } - } - } - - // if we haven't returned at this point, we should fallback to the union serializer - // which preserves the historical expectation that we do our best with serialization - // even if that means we resort to inference - union_serialize(selector, extra, choices, retry_with_lax_check) -} - impl TypeSerializer for UnionSerializer { fn to_python( &self, @@ -267,27 +229,21 @@ impl TypeSerializer for TaggedUnionSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> PyResult { - tagged_union_serialize( - self.get_discriminator_value(value, extra), - &self.lookup, + self.tagged_union_serialize( + value, |comb_serializer: &CombinedSerializer, new_extra: &Extra| { comb_serializer.to_python(value, include, exclude, new_extra) }, extra, - &self.choices, - self.retry_with_lax_check(), )? .map_or_else(|| infer_to_python(value, include, exclude, extra), Ok) } fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { - tagged_union_serialize( - self.get_discriminator_value(key, extra), - &self.lookup, + self.tagged_union_serialize( + key, |comb_serializer: &CombinedSerializer, new_extra: &Extra| comb_serializer.json_key(key, new_extra), extra, - &self.choices, - self.retry_with_lax_check(), )? .map_or_else(|| infer_json_key(key, extra), Ok) } @@ -300,15 +256,12 @@ impl TypeSerializer for TaggedUnionSerializer { exclude: Option<&Bound<'_, PyAny>>, extra: &Extra, ) -> Result { - match tagged_union_serialize( - None, - &self.lookup, + match self.tagged_union_serialize( + value, |comb_serializer: &CombinedSerializer, new_extra: &Extra| { comb_serializer.to_python(value, include, exclude, new_extra) }, extra, - &self.choices, - self.retry_with_lax_check(), ) { Ok(Some(v)) => return infer_serialize(v.bind(value.py()), serializer, None, None, extra), Ok(None) => infer_serialize(value, serializer, include, exclude, extra), @@ -326,36 +279,66 @@ impl TypeSerializer for TaggedUnionSerializer { } impl TaggedUnionSerializer { - fn get_discriminator_value(&self, value: &Bound<'_, PyAny>, extra: &Extra) -> Option> { + fn get_discriminator_value<'py>(&self, value: &Bound<'py, PyAny>) -> Option> { let py = value.py(); - let discriminator_value = match &self.discriminator { + match &self.discriminator { Discriminator::LookupKey(lookup_key) => { // we're pretty lax here, we allow either dict[key] or object.key, as we very well could // be doing a discriminator lookup on a typed dict, and there's no good way to check that // at this point. we could be more strict and only do this in lax mode... - let getattr_result = match value.is_instance_of::() { - true => { - let value_dict = value.downcast::().unwrap(); - lookup_key.py_get_dict_item(value_dict).ok() - } - false => lookup_key.simple_py_get_attr(value).ok(), - }; - getattr_result.and_then(|opt| opt.map(|(_, bound)| bound.to_object(py))) + if let Ok(value_dict) = value.downcast::() { + lookup_key.py_get_dict_item(value_dict).ok().flatten() + } else { + lookup_key.simple_py_get_attr(value).ok().flatten() + } + .map(|(_, tag)| tag) } - Discriminator::Function(func) => func.call1(py, (value,)).ok(), - }; - if discriminator_value.is_none() { - let value_str = truncate_safe_repr(value, None); + Discriminator::Function(func) => func.bind(py).call1((value,)).ok(), + } + } - // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise this warning - if extra.check == SerCheck::None { - extra.warnings.custom_warning( - format!( - "Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization." - ) - ); + fn tagged_union_serialize( + &self, + value: &Bound<'_, PyAny>, + // if this returns `Ok(v)`, we picked a union variant to serialize, where + // `S` is intermediate state which can be passed on to the finalizer + mut selector: impl FnMut(&CombinedSerializer, &Extra) -> PyResult, + extra: &Extra, + ) -> PyResult> { + if let Some(tag) = self.get_discriminator_value(value) { + let mut new_extra = extra.clone(); + new_extra.check = SerCheck::Strict; + + let tag_str = tag.to_string(); + if let Some(&serializer_index) = self.lookup.get(&tag_str) { + let selected_serializer = &self.choices[serializer_index]; + + match selector(selected_serializer, &new_extra) { + Ok(v) => return Ok(Some(v)), + Err(_) => { + if self.retry_with_lax_check() { + new_extra.check = SerCheck::Lax; + if let Ok(v) = selector(selected_serializer, &new_extra) { + return Ok(Some(v)); + } + } + } + } } + } else if extra.check == SerCheck::None { + // If extra.check is SerCheck::None, we're in a top-level union. We should thus raise + // this warning + let value_str = truncate_safe_repr(value, None); + extra.warnings.custom_warning( + format!( + "Failed to get discriminator value for tagged union serialization with value `{value_str}` - defaulting to left to right union serialization." + ) + ); } - discriminator_value + + // if we haven't returned at this point, we should fallback to the union serializer + // which preserves the historical expectation that we do our best with serialization + // even if that means we resort to inference + union_serialize(selector, extra, &self.choices, self.retry_with_lax_check()) } }