From ff7f643e3560ed8d78ce52f32ed1a0f552d0cb38 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Tue, 26 Sep 2023 20:30:47 -0700 Subject: [PATCH] Pydantic 2 support (#847) * Update api utils to work with pydantic 2 * Fix api sanitize * Fix MPDataDoc creation * Lazy load all nested resters * Bump to python>=3.9 * Cache api sanitize * Cache emmet version retrieval * Fix client tests * Migrate __fields__ * Change materials test input * Move flat models util func to emmet * api_sanitize allow_dict behavior * Bump emmet * Update maggma util import * Linting and deprecated changes * Spelling * Final linting * Bump emmet * Linting * Remove repeat code * Linting * Add maggma to deps --- mp_api/client/core/client.py | 22 ++++++------ mp_api/client/core/settings.py | 3 +- mp_api/client/core/utils.py | 35 ++++++++++--------- mp_api/client/mprester.py | 2 +- .../routes/materials/electronic_structure.py | 6 ++-- pyproject.toml | 7 ++-- tests/materials/core_function.py | 2 +- tests/materials/test_electronic_structure.py | 4 +-- tests/molecules/core_function.py | 2 +- tests/test_client.py | 2 +- 10 files changed, 44 insertions(+), 41 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 32d5cac6..be864b6f 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -842,7 +842,7 @@ def _submit_request_and_process( data_model( **{ field: value - for field, value in raw_doc.dict().items() + for field, value in raw_doc.model_dump().items() if field in set_fields } ) @@ -877,29 +877,29 @@ def _submit_request_and_process( def _generate_returned_model(self, doc): set_fields = [ - field for field, _ in doc if field in doc.dict(exclude_unset=True) + field for field, _ in doc if field in doc.model_dump(exclude_unset=True) ] - unset_fields = [field for field in doc.__fields__ if field not in set_fields] + unset_fields = [field for field in doc.model_fields if field not in set_fields] data_model = create_model( "MPDataDoc", - fields_not_requested=unset_fields, + fields_not_requested=(list[str], unset_fields), __base__=self.document_model, ) - data_model.__fields__ = { + data_model.model_fields = { **{ name: description - for name, description in data_model.__fields__.items() + for name, description in data_model.model_fields.items() if name in set_fields }, - "fields_not_requested": data_model.__fields__["fields_not_requested"], + "fields_not_requested": data_model.model_fields["fields_not_requested"], } def new_repr(self) -> str: extra = ",\n".join( f"\033[1m{n}\033[0;0m={getattr(self, n)!r}" - for n in data_model.__fields__ + for n in data_model.model_fields ) s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501 @@ -908,7 +908,7 @@ def new_repr(self) -> str: def new_str(self) -> str: extra = ",\n".join( f"\033[1m{n}\033[0;0m={getattr(self, n)!r}" - for n in data_model.__fields__ + for n in data_model.model_fields if n != "fields_not_requested" ) @@ -927,7 +927,7 @@ def new_getattr(self, attr) -> str: ) def new_dict(self, *args, **kwargs): - d = super(data_model, self).dict(*args, **kwargs) + d = super(data_model, self).model_dump(*args, **kwargs) return jsanitize(d) data_model.__repr__ = new_repr @@ -1155,7 +1155,7 @@ def count(self, criteria: dict | None = None) -> int | str: def available_fields(self) -> list[str]: if self.document_model is None: return ["Unknown fields."] - return list(self.document_model.schema()["properties"].keys()) # type: ignore + return list(self.document_model.model_json_schema()["properties"].keys()) # type: ignore def __repr__(self): # pragma: no cover return f"<{self.__class__.__name__} {self.endpoint}>" diff --git a/mp_api/client/core/settings.py b/mp_api/client/core/settings.py index 2cc1ccca..2ee3a2ab 100644 --- a/mp_api/client/core/settings.py +++ b/mp_api/client/core/settings.py @@ -2,7 +2,8 @@ from multiprocessing import cpu_count from typing import List -from pydantic import BaseSettings, Field +from pydantic import Field +from pydantic_settings import BaseSettings from pymatgen.core import _load_pmg_settings from mp_api.client import __file__ as root_dir diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 417cfb8a..17f2c59f 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -2,12 +2,13 @@ import re from functools import cache -from typing import get_args +from typing import Optional, get_args +from maggma.utils import get_flat_models_from_model from monty.json import MSONable from pydantic import BaseModel -from pydantic.schema import get_flat_models_from_model -from pydantic.utils import lenient_issubclass +from pydantic._internal._utils import lenient_issubclass +from pydantic.fields import FieldInfo def validate_ids(id_list: list[str]): @@ -62,25 +63,25 @@ def api_sanitize( for model in models: model_fields_to_leave = {f[1] for f in fields_tuples if model.__name__ == f[0]} - for name, field in model.__fields__.items(): - field_type = field.type_ - - if name not in model_fields_to_leave: - field.required = False - field.default = None - field.default_factory = None - field.allow_none = True - field.field_info.default = None - field.field_info.default_factory = None + for name in model.model_fields: + field = model.model_fields[name] + field_type = field.annotation if field_type is not None and allow_dict_msonable: if lenient_issubclass(field_type, MSONable): - field.type_ = allow_msonable_dict(field_type) + field_type = allow_msonable_dict(field_type) else: for sub_type in get_args(field_type): if lenient_issubclass(sub_type, MSONable): allow_msonable_dict(sub_type) - field.populate_validators() + + if name not in model_fields_to_leave: + new_field = FieldInfo.from_annotated_attribute( + Optional[field_type], None + ) + model.model_fields[name] = new_field + + model.model_rebuild(force=True) return pydantic_model @@ -88,7 +89,7 @@ def api_sanitize( def allow_msonable_dict(monty_cls: type[MSONable]): """Patch Monty to allow for dict values for MSONable.""" - def validate_monty(cls, v): + def validate_monty(cls, v, _): """Stub validator for MSONable as a dictionary only.""" if isinstance(v, cls): return v @@ -110,6 +111,6 @@ def validate_monty(cls, v): else: raise ValueError(f"Must provide {cls.__name__} or MSONable dictionary") - monty_cls.validate_monty = classmethod(validate_monty) + monty_cls.validate_monty_v2 = classmethod(validate_monty) return monty_cls diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index d9fb0711..13350805 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -748,7 +748,7 @@ def get_entries( if property_data: for property in property_data: entry_dict["data"][property] = ( - doc.dict()[property] + doc.model_dump()[property] if self.use_document_model else doc[property] ) diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index 5d37f0ea..94e220da 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -285,7 +285,7 @@ def get_bandstructure_from_material_id( f"No {path_type.value} band structure data found for {material_id}" ) else: - bs_data = bs_data.dict() + bs_data = bs_data.model_dump() if bs_data.get(path_type.value, None): bs_task_id = bs_data[path_type.value]["task_id"] @@ -303,7 +303,7 @@ def get_bandstructure_from_material_id( f"No uniform band structure data found for {material_id}" ) else: - bs_data = bs_data.dict() + bs_data = bs_data.model_dump() if bs_data.get("total", None): bs_task_id = bs_data["total"]["1"]["task_id"] @@ -444,7 +444,7 @@ def get_dos_from_material_id(self, material_id: str): dos_data = es_rester.get_data_by_id( document_id=material_id, fields=["dos"] - ).dict() + ).model_dump() if dos_data["dos"]: dos_task_id = dos_data["dos"]["total"]["1"]["task_id"] diff --git a/pyproject.toml b/pyproject.toml index 2d7097f9..5febbcc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,16 +22,17 @@ classifiers = [ dependencies = [ "setuptools", "msgpack", + "maggma", "pymatgen>=2022.3.7", "typing-extensions>=3.7.4.1", "requests>=2.23.0", - "monty>=2021.3.12", - "emmet-core>=0.54.0", + "monty>=2023.9.25", + "emmet-core>=0.69.2", ] dynamic = ["version"] [project.optional-dependencies] -all = ["emmet-core[all]>=0.54.0", "custodian", "mpcontribs-client", "boto3"] +all = ["emmet-core[all]>=0.69.1", "custodian", "mpcontribs-client", "boto3"] test = [ "pre-commit", "pytest", diff --git a/tests/materials/core_function.py b/tests/materials/core_function.py index 12871d2b..19c2ac2b 100644 --- a/tests/materials/core_function.py +++ b/tests/materials/core_function.py @@ -60,7 +60,7 @@ def client_search_testing( "num_chunks": 1, } - doc = search_method(**q)[0].dict() + doc = search_method(**q)[0].model_dump() for sub_field in sub_doc_fields: if sub_field in doc: diff --git a/tests/materials/test_electronic_structure.py b/tests/materials/test_electronic_structure.py index 5bf08135..c1d91f81 100644 --- a/tests/materials/test_electronic_structure.py +++ b/tests/materials/test_electronic_structure.py @@ -92,7 +92,7 @@ def test_bs_client(bs_rester): "chunk_size": 1, "num_chunks": 1, } - doc = search_method(**q)[0].dict() + doc = search_method(**q)[0].model_dump() for sub_field in bs_sub_doc_fields: if sub_field in doc: @@ -137,7 +137,7 @@ def test_dos_client(dos_rester): "chunk_size": 1, "num_chunks": 1, } - doc = search_method(**q)[0].dict() + doc = search_method(**q)[0].model_dump() for sub_field in dos_sub_doc_fields: if sub_field in doc: doc = doc[sub_field] diff --git a/tests/molecules/core_function.py b/tests/molecules/core_function.py index 9fac742f..557cafb3 100644 --- a/tests/molecules/core_function.py +++ b/tests/molecules/core_function.py @@ -58,7 +58,7 @@ def client_search_testing( docs = search_method(**q) if len(docs) > 0: - doc = docs[0].dict() + doc = docs[0].model_dump() else: raise ValueError("No documents returned") diff --git a/tests/test_client.py b/tests/test_client.py index b4f11781..f3b9ac9c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -78,7 +78,7 @@ def test_generic_get_methods(rester): if name not in search_only_resters: doc = rester.get_data_by_id( - doc.dict()[rester.primary_key], fields=[rester.primary_key] + doc.model_dump()[rester.primary_key], fields=[rester.primary_key] ) assert isinstance(doc, rester.document_model)