Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General repo and tooling cleanup #839

Merged
merged 26 commits into main
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ on:
pull_request:
branches: [main]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
test:
strategy:
Expand Down
20 changes: 14 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
default_stages: [commit]

default_install_hook_types: [pre-commit, commit-msg]

ci:
autoupdate_commit_msg: "chore: update pre-commit hooks"

repos:
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.261
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.284
hooks:
- id: ruff
args: [--fix, --ignore, "D,E501"]
args: [--fix, --show-fixes]

- repo: https://github.com/psf/black
rev: 23.3.0
rev: 23.7.0
hooks:
- id: black

- repo: https://github.com/asottile/blacken-docs
rev: "1.15.0"
hooks:
- id: black-jupyter
- id: blacken-docs
additional_dependencies: [black>=23.7.0]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
Expand Down
2 changes: 2 additions & 0 deletions mp_api/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Primary MAPI module."""
from __future__ import annotations

import os
from importlib.metadata import PackageNotFoundError, version

Expand Down
2 changes: 2 additions & 0 deletions mp_api/client/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from __future__ import annotations

from .client import BaseRester, MPRestError
from .settings import MAPIClientSettings
92 changes: 48 additions & 44 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
API v3 to enable the creation of data structures and pymatgen objects using
Materials Project data.
"""
from __future__ import annotations

import gzip
import itertools
Expand All @@ -14,7 +15,7 @@
from json import JSONDecodeError
from math import ceil
from os import environ
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from typing import Any, Generic, TypeVar
from urllib.parse import quote, urljoin

import requests
Expand Down Expand Up @@ -52,18 +53,18 @@
class BaseRester(Generic[T]):
"""Base client class with core stubs."""

suffix: Optional[str] = None
suffix: str | None = None
document_model: BaseModel = None # type: ignore
supports_versions: bool = False
primary_key: str = "material_id"

def __init__(
self,
api_key: Union[str, None] = None,
api_key: str | None = None,
endpoint: str = DEFAULT_ENDPOINT,
include_user_agent: bool = True,
session: Optional[requests.Session] = None,
s3_resource: Optional[Any] = None,
session: requests.Session | None = None,
s3_resource: Any | None = None,
debug: bool = False,
monty_decode: bool = True,
use_document_model: bool = True,
Expand Down Expand Up @@ -191,11 +192,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): # pragma: no cover

def _post_resource(
self,
body: Dict = None,
params: Optional[Dict] = None,
suburl: Optional[str] = None,
use_document_model: Optional[bool] = None,
) -> Dict:
body: dict = None,
params: dict | None = None,
suburl: str | None = None,
use_document_model: bool | None = None,
) -> dict:
"""Post data to the endpoint for a Resource.

Arguments:
Expand Down Expand Up @@ -261,11 +262,11 @@ def _post_resource(

def _patch_resource(
self,
body: Dict = None,
params: Optional[Dict] = None,
suburl: Optional[str] = None,
use_document_model: Optional[bool] = None,
) -> Dict:
body: dict = None,
params: dict | None = None,
suburl: str | None = None,
use_document_model: bool | None = None,
) -> dict:
"""Patch data to the endpoint for a Resource.

Arguments:
Expand Down Expand Up @@ -330,7 +331,7 @@ def _patch_resource(
raise MPRestError(str(ex))

def _query_open_data(self, bucket: str, prefix: str, key: str) -> dict:
"""Query Materials Project AWS open data s3 buckets
"""Query Materials Project AWS open data s3 buckets.

Args:
bucket (str): Materials project bucket name
Expand All @@ -340,7 +341,6 @@ def _query_open_data(self, bucket: str, prefix: str, key: str) -> dict:
Returns:
dict: MontyDecoded data
"""

ref = self.s3_resource.Object(bucket, f"{prefix}/{key}.json.gz") # type: ignore
bytes = ref.get()["Body"] # type: ignore

Expand All @@ -352,15 +352,15 @@ def _query_open_data(self, bucket: str, prefix: str, key: str) -> dict:

def _query_resource(
self,
criteria: Optional[Dict] = None,
fields: Optional[List[str]] = None,
suburl: Optional[str] = None,
use_document_model: Optional[bool] = None,
parallel_param: Optional[str] = None,
num_chunks: Optional[int] = None,
chunk_size: Optional[int] = None,
timeout: Optional[int] = None,
) -> Dict:
criteria: dict | None = None,
fields: list[str] | None = None,
suburl: str | None = None,
use_document_model: bool | None = None,
parallel_param: str | None = None,
num_chunks: int | None = None,
chunk_size: int | None = None,
timeout: int | None = None,
) -> dict:
"""Query the endpoint for a Resource containing a list of documents
and meta information about pagination and total document count.

Expand Down Expand Up @@ -429,7 +429,7 @@ def _submit_requests(
num_chunks=None,
chunk_size=None,
timeout=None,
) -> Dict:
) -> dict:
"""Handle submitting requests. Parallel requests supported if possible.
Parallelization will occur either over the largest list of supported
query parameters used and/or over pagination.
Expand Down Expand Up @@ -712,7 +712,7 @@ def _submit_requests(
def _multi_thread(
self,
use_document_model: bool,
params_list: List[dict],
params_list: list[dict],
progress_bar: tqdm = None,
timeout: int = None,
):
Expand Down Expand Up @@ -788,7 +788,7 @@ def _submit_request_and_process(
params: dict,
use_document_model: bool,
timeout: int = None,
) -> Tuple[Dict, int]:
) -> tuple[dict, int]:
"""Submits GET request and handles the response.

Arguments:
Expand Down Expand Up @@ -936,12 +936,12 @@ def new_dict(self, *args, **kwargs):

def _query_resource_data(
self,
criteria: Optional[Dict] = None,
fields: Optional[List[str]] = None,
suburl: Optional[str] = None,
use_document_model: Optional[bool] = None,
timeout: Optional[int] = None,
) -> Union[List[T], List[Dict]]:
criteria: dict | None = None,
fields: list[str] | None = None,
suburl: str | None = None,
use_document_model: bool | None = None,
timeout: int | None = None,
) -> list[T] | list[dict]:
"""Query the endpoint for a list of documents without associated meta information. Only
returns a single page of results.

Expand All @@ -967,7 +967,7 @@ def _query_resource_data(
def get_data_by_id(
self,
document_id: str,
fields: Optional[List[str]] = None,
fields: list[str] | None = None,
) -> T:
"""Query the endpoint for a single document.

Expand Down Expand Up @@ -1039,12 +1039,12 @@ def get_data_by_id(

def _search(
self,
num_chunks: Optional[int] = None,
num_chunks: int | None = None,
chunk_size: int = 1000,
all_fields: bool = True,
fields: Optional[List[str]] = None,
fields: list[str] | None = None,
**kwargs,
) -> Union[List[T], List[Dict]]:
) -> list[T] | list[dict]:
"""A generic search method to retrieve documents matching specific parameters.

Arguments:
Expand Down Expand Up @@ -1082,7 +1082,7 @@ def _get_all_documents(
fields=None,
chunk_size=1000,
num_chunks=None,
) -> Union[List[T], List[Dict]]:
) -> list[T] | list[dict]:
"""Iterates over pages until all documents are retrieved. Displays
progress using tqdm. This method is designed to give a common
implementation for the search_* methods on various endpoints. See
Expand Down Expand Up @@ -1124,10 +1124,14 @@ def _get_all_documents(

return results["data"]

def count(self, criteria: Optional[Dict] = None) -> Union[int, str]:
def count(self, criteria: dict | None = None) -> int | str:
"""Return a count of total documents.
:param criteria: As in .query()
:return:

Args:
criteria (dict | None): As in .query(). Defaults to None

Returns:
(int | str): Count of total results, or string indicating error
"""
try:
criteria = criteria or {}
Expand All @@ -1145,7 +1149,7 @@ def count(self, criteria: Optional[Dict] = None) -> Union[int, str]:
return "Problem getting count"

@property
def available_fields(self) -> List[str]:
def available_fields(self) -> list[str]:
if self.document_model is None:
return ["Unknown fields."]
return list(self.document_model.schema()["properties"].keys()) # type: ignore
Expand Down
20 changes: 13 additions & 7 deletions mp_api/client/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

import re
from typing import List, Optional, Type, get_args
from typing import get_args

from monty.json import MSONable
from pydantic import BaseModel
from pydantic.schema import get_flat_models_from_model
from pydantic.utils import lenient_issubclass


def validate_ids(id_list: List[str]):
def validate_ids(id_list: list[str]):
"""Function to validate material and task IDs.

Args:
Expand All @@ -29,8 +31,8 @@ def validate_ids(id_list: List[str]):


def api_sanitize(
pydantic_model: Type[BaseModel],
fields_to_leave: Optional[List[str]] = None,
pydantic_model: BaseModel,
fields_to_leave: list[str] | None = None,
allow_dict_msonable=False,
):
"""Function to clean up pydantic models for the API by:
Expand All @@ -40,13 +42,17 @@ def api_sanitize(
WARNING: This works in place, so it mutates the model and all sub-models

Args:
fields_to_leave: list of strings for model fields as "model__name__.field"
pydantic_model (BaseModel): Pydantic model to alter
fields_to_leave (list[str] | None): list of strings for model fields as "model__name__.field".
Defaults to None.
allow_dict_msonable (bool): Whether to allow dictionaries in place of MSONable quantities.
Defaults to False
"""
models = [
model
for model in get_flat_models_from_model(pydantic_model)
if issubclass(model, BaseModel)
] # type: List[Type[BaseModel]]
] # type: list[BaseModel]

fields_to_leave = fields_to_leave or []
fields_tuples = [f.split(".") for f in fields_to_leave]
Expand Down Expand Up @@ -77,7 +83,7 @@ def api_sanitize(
return pydantic_model


def allow_msonable_dict(monty_cls: Type[MSONable]):
def allow_msonable_dict(monty_cls: type[MSONable]):
"""Patch Monty to allow for dict values for MSONable."""

def validate_monty(cls, v):
Expand Down
Loading