Skip to content

Commit

Permalink
Improvements in union matching logic during validation (#1332)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle authored Jun 19, 2024
1 parent 46379ac commit fcc77f8
Show file tree
Hide file tree
Showing 6 changed files with 551 additions and 26 deletions.
14 changes: 12 additions & 2 deletions src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ impl Validator for DataclassArgsValidator {
let mut used_keys: AHashSet<&str> = AHashSet::with_capacity(self.fields.len());

let state = &mut state.rebind_extra(|extra| extra.data = Some(output_dict.clone()));
let mut fields_set_count: usize = 0;

macro_rules! set_item {
($field:ident, $value:expr) => {{
Expand All @@ -175,6 +176,7 @@ impl Validator for DataclassArgsValidator {
Ok(Some(value)) => {
// Default value exists, and passed validation if required
set_item!(field, value);
fields_set_count += 1;
}
Ok(None) | Err(ValError::Omit) => continue,
// Note: this will always use the field name even if there is an alias
Expand Down Expand Up @@ -214,15 +216,21 @@ impl Validator for DataclassArgsValidator {
}
// found a positional argument, validate it
(Some(pos_value), None) => match field.validator.validate(py, pos_value.borrow_input(), state) {
Ok(value) => set_item!(field, value),
Ok(value) => {
set_item!(field, value);
fields_set_count += 1;
}
Err(ValError::LineErrors(line_errors)) => {
errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index)));
}
Err(err) => return Err(err),
},
// found a keyword argument, validate it
(None, Some((lookup_path, kw_value))) => match field.validator.validate(py, kw_value, state) {
Ok(value) => set_item!(field, value),
Ok(value) => {
set_item!(field, value);
fields_set_count += 1;
}
Err(ValError::LineErrors(line_errors)) => {
errors.extend(
line_errors
Expand Down Expand Up @@ -336,6 +344,8 @@ impl Validator for DataclassArgsValidator {
}
}

state.add_fields_set(fields_set_count);

if errors.is_empty() {
if let Some(init_only_args) = init_only_args {
Ok((output_dict, PyTuple::new_bound(py, init_only_args)).to_object(py))
Expand Down
12 changes: 9 additions & 3 deletions src/validators/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ impl Validator for ModelValidator {
for field_name in validated_fields_set {
fields_set.add(field_name)?;
}
state.add_fields_set(fields_set.len());
}

force_setattr(py, model, intern!(py, DUNDER_DICT), validated_dict.to_object(py))?;
Expand Down Expand Up @@ -241,10 +242,13 @@ impl ModelValidator {
} else {
PySet::new_bound(py, [&String::from(ROOT_FIELD)])?
};
force_setattr(py, self_instance, intern!(py, DUNDER_FIELDS_SET_KEY), fields_set)?;
force_setattr(py, self_instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?;
force_setattr(py, self_instance, intern!(py, ROOT_FIELD), &output)?;
state.add_fields_set(fields_set.len());
} else {
let (model_dict, model_extra, fields_set) = output.extract(py)?;
let (model_dict, model_extra, fields_set): (Bound<PyAny>, Bound<PyAny>, Bound<PyAny>) =
output.extract(py)?;
state.add_fields_set(fields_set.len().unwrap_or(0));
set_model_attrs(self_instance, &model_dict, &model_extra, &fields_set)?;
}
self.call_post_init(py, self_instance.clone(), input, state.extra())
Expand Down Expand Up @@ -281,11 +285,13 @@ impl ModelValidator {
} else {
PySet::new_bound(py, [&String::from(ROOT_FIELD)])?
};
force_setattr(py, &instance, intern!(py, DUNDER_FIELDS_SET_KEY), fields_set)?;
force_setattr(py, &instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?;
force_setattr(py, &instance, intern!(py, ROOT_FIELD), output)?;
state.add_fields_set(fields_set.len());
} else {
let (model_dict, model_extra, val_fields_set) = output.extract(py)?;
let fields_set = existing_fields_set.unwrap_or(&val_fields_set);
state.add_fields_set(fields_set.len().unwrap_or(0));
set_model_attrs(&instance, &model_dict, &model_extra, fields_set)?;
}
self.call_post_init(py, instance, input, state.extra())
Expand Down
4 changes: 4 additions & 0 deletions src/validators/typed_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ impl Validator for TypedDictValidator {

{
let state = &mut state.rebind_extra(|extra| extra.data = Some(output_dict.clone()));
let mut fields_set_count: usize = 0;

for field in &self.fields {
let op_key_value = match dict.get_item(&field.lookup_key) {
Expand All @@ -186,6 +187,7 @@ impl Validator for TypedDictValidator {
match field.validator.validate(py, value.borrow_input(), state) {
Ok(value) => {
output_dict.set_item(&field.name_py, value)?;
fields_set_count += 1;
}
Err(ValError::Omit) => continue,
Err(ValError::LineErrors(line_errors)) => {
Expand Down Expand Up @@ -227,6 +229,8 @@ impl Validator for TypedDictValidator {
Err(err) => return Err(err),
}
}

state.add_fields_set(fields_set_count);
}

if let Some(used_keys) = used_keys {
Expand Down
54 changes: 38 additions & 16 deletions src/validators/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,12 @@ impl UnionValidator {
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
let old_exactness = state.exactness;
let old_fields_set_count = state.fields_set_count;

let strict = state.strict_or(self.strict);
let mut errors = MaybeErrors::new(self.custom_error.as_ref());

let mut success = None;
let mut best_match: Option<(Py<PyAny>, Exactness, Option<usize>)> = None;

for (choice, label) in &self.choices {
let state = &mut state.rebind_extra(|extra| {
Expand All @@ -120,47 +122,67 @@ impl UnionValidator {
}
});
state.exactness = Some(Exactness::Exact);
state.fields_set_count = None;
let result = choice.validate(py, input, state);
match result {
Ok(new_success) => match state.exactness {
// exact match, return
Some(Exactness::Exact) => {
Ok(new_success) => match (state.exactness, state.fields_set_count) {
(Some(Exactness::Exact), None) => {
// exact match with no fields set data, return immediately
return {
// exact match, return, restore any previous exactness
state.exactness = old_exactness;
state.fields_set_count = old_fields_set_count;
Ok(new_success)
};
}
_ => {
// success should always have an exactness
debug_assert_ne!(state.exactness, None);

let new_exactness = state.exactness.unwrap_or(Exactness::Lax);
// if the new result has higher exactness than the current success, replace it
if success
.as_ref()
.map_or(true, |(_, current_exactness)| *current_exactness < new_exactness)
{
// TODO: is there a possible optimization here, where once there has
// been one success, we turn on strict mode, to avoid unnecessary
// coercions for further validation?
success = Some((new_success, new_exactness));
let new_fields_set_count = state.fields_set_count;

// Both the exactness and the fields_set_count are used to pick the best union member match:
// - If fields_set_count is available for both the current best match and the new candidate
//   and the two counts differ, the count is the primary metric: the greater count wins.
// - If the two counts are equal, exactness is used as the tie breaker.
// - If fields_set_count is missing for the current best match, the new candidate, or both,
//   only exactness is compared.
// Whenever exactness is compared, the new candidate must be strictly more exact to replace
// the current best match.
let new_success_is_best_match: bool =
best_match
.as_ref()
.map_or(true, |(_, cur_exactness, cur_fields_set_count)| {
match (*cur_fields_set_count, new_fields_set_count) {
(Some(cur), Some(new)) if cur != new => cur < new,
_ => *cur_exactness < new_exactness,
}
});

if new_success_is_best_match {
best_match = Some((new_success, new_exactness, new_fields_set_count));
}
}
},
Err(ValError::LineErrors(lines)) => {
// if we don't yet know this validation will succeed, record the error
if success.is_none() {
if best_match.is_none() {
errors.push(choice, label.as_deref(), lines);
}
}
otherwise => return otherwise,
}
}

// restore previous validation state to prepare for any future validations
state.exactness = old_exactness;
state.fields_set_count = old_fields_set_count;

if let Some((success, exactness)) = success {
if let Some((best_match, exactness, fields_set_count)) = best_match {
state.floor_exactness(exactness);
return Ok(success);
if let Some(count) = fields_set_count {
state.add_fields_set(count);
}
return Ok(best_match);
}

// no matches, build errors
Expand Down
6 changes: 6 additions & 0 deletions src/validators/validation_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub enum Exactness {
/// State handed to validators (as `state`) while validating an input.
///
/// Besides the recursion guard and the read-only `Extra`, it carries two optional metrics that
/// the union validator uses to choose between successful choices: `exactness` and
/// `fields_set_count`.
pub struct ValidationState<'a, 'py> {
pub recursion_guard: &'a mut RecursionState,
/// Exactness of the validation so far. Initialised to `None` by the constructor, which does
/// not care about exactness unless union validation is being done; the union validator sets
/// it to `Some(Exactness::Exact)` before trying each choice and restores the previous value
/// afterwards.
pub exactness: Option<Exactness>,
/// Running total of fields set during validation, accumulated through `add_fields_set`.
/// `None` means no count has been recorded; the union validator resets it to `None` before
/// trying each choice and reads it back to compare candidates.
pub fields_set_count: Option<usize>,
// deliberately make Extra readonly
extra: Extra<'a, 'py>,
}
Expand All @@ -27,6 +28,7 @@ impl<'a, 'py> ValidationState<'a, 'py> {
Self {
recursion_guard, // Don't care about exactness unless doing union validation
exactness: None,
fields_set_count: None,
extra,
}
}
Expand Down Expand Up @@ -68,6 +70,10 @@ impl<'a, 'py> ValidationState<'a, 'py> {
}
}

/// Adds `fields_set_count` to the running total stored in `self.fields_set_count`.
///
/// When no total has been recorded yet (`None`), counting starts from zero, so the stored
/// value is always `Some(..)` after this call.
pub fn add_fields_set(&mut self, fields_set_count: usize) {
    // Treat "nothing recorded yet" as zero, then store the updated total.
    let previous_total = self.fields_set_count.unwrap_or(0);
    self.fields_set_count = Some(previous_total + fields_set_count);
}

pub fn cache_str(&self) -> StringCacheMode {
self.extra.cache_str
}
Expand Down
Loading

0 comments on commit fcc77f8

Please sign in to comment.