Header dictionary pass through and BaseRester nesting fix (#715)
* Import MaterialsRester instead of MPRester in base

* Pull data for new doc id

* Add header passthrough to rester

* Pass headers to session creation

* Fix default headers value

* More default header fixes

* More variable fixes
Jason Munro authored Dec 6, 2022
1 parent 49fc74a commit 6bf1791
Showing 2 changed files with 58 additions and 107 deletions.
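In practical terms, this change lets callers supply a custom header dictionary that is attached to every request the client makes. A minimal usage sketch, assuming the behavior added in the diff below; the endpoint URL, header name and value, and the query are illustrative placeholders, not part of this commit:

```python
from mp_api.client import MPRester

# Hypothetical custom headers for a locally hosted API deployment; the new
# docstring below describes the kwarg as "Custom headers for localhost
# connections". Header names and values here are placeholders.
custom_headers = {"X-Forwarded-Host": "materials.internal.example"}

with MPRester(
    api_key="<your-32-character-api-key>",  # placeholder
    endpoint="http://localhost:5000/",      # placeholder local endpoint
    headers=custom_headers,                 # new keyword introduced by this commit
) as mpr:
    # Any normal query works; the custom headers ride along on every request.
    structure = mpr.get_structure_by_material_id("mp-149")
```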
156 changes: 51 additions & 105 deletions mp_api/client/core/client.py
@@ -62,6 +62,7 @@ def __init__(
monty_decode: bool = True,
use_document_model: bool = True,
timeout: int = 20,
headers: dict = None,
):
"""
Args:
@@ -89,6 +90,7 @@ def __init__(
as a dictionary. This can be simpler to work with but bypasses data validation
and will not give auto-complete for available fields.
timeout: Time in seconds to wait until a request timeout error is thrown
headers (dict): Custom headers for localhost connections.
"""

self.api_key = api_key
@@ -99,6 +101,7 @@ def __init__(
self.monty_decode = monty_decode
self.use_document_model = use_document_model
self.timeout = timeout
self.headers = headers or {}

if self.suffix:
self.endpoint = urljoin(self.endpoint, self.suffix)
@@ -117,20 +120,20 @@ def __init__(
@property
def session(self) -> requests.Session:
if not self._session:
self._session = self._create_session(self.api_key, self.include_user_agent)
self._session = self._create_session(self.api_key, self.include_user_agent, self.headers)
return self._session

@staticmethod
def _create_session(api_key, include_user_agent):
def _create_session(api_key, include_user_agent, headers):
session = requests.Session()
session.headers = {"x-api-key": api_key}
session.headers.update(headers)

if include_user_agent:
pymatgen_info = "pymatgen/" + pmg_version
python_info = f"Python/{sys.version.split()[0]}"
platform_info = f"{platform.system()}/{platform.release()}"
session.headers[
"user-agent"
] = f"{pymatgen_info} ({python_info} {platform_info})"
session.headers["user-agent"] = f"{pymatgen_info} ({python_info} {platform_info})"

max_retry_num = MAPIClientSettings().MAX_RETRIES
retry = Retry(
@@ -219,9 +222,7 @@ def _post_resource(
message = data
else:
try:
message = ", ".join(
f"{entry['loc'][1]} - {entry['msg']}" for entry in data
)
message = ", ".join(f"{entry['loc'][1]} - {entry['msg']}" for entry in data)
except (KeyError, IndexError):
message = str(data)

@@ -352,17 +353,13 @@ def _submit_requests(
url_string += f"{key}={parsed_val}&"

bare_url_len = len(url_string)
max_param_str_length = (
MAPIClientSettings().MAX_HTTP_URL_LENGTH - bare_url_len
)
max_param_str_length = MAPIClientSettings().MAX_HTTP_URL_LENGTH - bare_url_len

# Next, check if default number of parallel requests works.
# If not, make slice size the minimum number of param entries
# contained in any substring of length max_param_str_length.
param_length = len(criteria[parallel_param].split(","))
slice_size = (
int(param_length / MAPIClientSettings().NUM_PARALLEL_REQUESTS) or 1
)
slice_size = int(param_length / MAPIClientSettings().NUM_PARALLEL_REQUESTS) or 1

url_param_string = quote(criteria[parallel_param])

@@ -374,9 +371,7 @@

if len(parallel_param_str_chunks) > 0:

params_min_chunk = min(
parallel_param_str_chunks, key=lambda x: len(x.split("%2C"))
)
params_min_chunk = min(parallel_param_str_chunks, key=lambda x: len(x.split("%2C")))

num_params_min_chunk = len(params_min_chunk.split("%2C"))

@@ -406,11 +401,7 @@ def _submit_requests(
# Split list and generate multiple criteria
new_criteria = [
{
**{
key: criteria[key]
for key in criteria
if key not in [parallel_param, "_limit"]
},
**{key: criteria[key] for key in criteria if key not in [parallel_param, "_limit"]},
parallel_param: ",".join(list_chunk),
"_limit": new_limits[list_num],
}
@@ -433,13 +424,9 @@ def _submit_requests(
subtotals = []
remaining_docs_avail = {}

initial_params_list = [
{"url": url, "verify": True, "params": copy(crit)} for crit in new_criteria
]
initial_params_list = [{"url": url, "verify": True, "params": copy(crit)} for crit in new_criteria]

initial_data_tuples = self._multi_thread(
use_document_model, initial_params_list
)
initial_data_tuples = self._multi_thread(use_document_model, initial_params_list)

for data, subtotal, crit_ind in initial_data_tuples:

@@ -452,9 +439,7 @@ def _submit_requests(

# Rebalance if some parallel queries produced too few results
if len(remaining_docs_avail) > 1 and len(total_data["data"]) < chunk_size:
remaining_docs_avail = dict(
sorted(remaining_docs_avail.items(), key=lambda item: item[1])
)
remaining_docs_avail = dict(sorted(remaining_docs_avail.items(), key=lambda item: item[1]))

# Redistribute missing docs from initial chunk among queries
# which have head room with respect to remaining document number.
@@ -481,19 +466,15 @@ def _submit_requests(
new_limits[crit_ind] += fill_docs
fill_docs = 0

rebalance_params.append(
{"url": url, "verify": True, "params": copy(crit)}
)
rebalance_params.append({"url": url, "verify": True, "params": copy(crit)})

new_criteria[crit_ind]["_skip"] += crit["_limit"]
new_criteria[crit_ind]["_limit"] = chunk_size

# Obtain missing initial data after rebalancing
if len(rebalance_params) > 0:

rebalance_data_tuples = self._multi_thread(
use_document_model, rebalance_params
)
rebalance_data_tuples = self._multi_thread(use_document_model, rebalance_params)

for data, _, _ in rebalance_data_tuples:
total_data["data"].extend(data["data"])
@@ -507,9 +488,7 @@ def _submit_requests(
total_data["meta"] = last_data_entry["meta"]

# Get max number of response pages
max_pages = (
num_chunks if num_chunks is not None else ceil(total_num_docs / chunk_size)
)
max_pages = num_chunks if num_chunks is not None else ceil(total_num_docs / chunk_size)

# Get total number of docs needed
num_docs_needed = min((max_pages * chunk_size), total_num_docs)
@@ -625,22 +604,16 @@ def _multi_thread(

return_data = []

params_gen = iter(
params_list
) # Iter necessary for islice to keep track of what has been accessed
params_gen = iter(params_list) # Iter necessary for islice to keep track of what has been accessed

params_ind = 0

with ThreadPoolExecutor(
max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS
) as executor:
with ThreadPoolExecutor(max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS) as executor:

# Get list of initial futures defined by max number of parallel requests
futures = set()

for params in itertools.islice(
params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS
):
for params in itertools.islice(params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS):

future = executor.submit(
self._submit_request_and_process,
@@ -702,13 +675,9 @@ def _submit_request_and_process(
Tuple with data and total number of docs in matching the query in the database.
"""
try:
response = self.session.get(
url=url, verify=verify, params=params, timeout=timeout
)
response = self.session.get(url=url, verify=verify, params=params, timeout=timeout)
except requests.exceptions.ConnectTimeout:
raise MPRestError(
f"REST query timed out on URL {url}. Try again with a smaller request."
)
raise MPRestError(f"REST query timed out on URL {url}. Try again with a smaller request.")

if response.status_code == 200:

@@ -724,18 +693,10 @@ def _submit_request_and_process(
raw_doc_list = [self.document_model.parse_obj(d) for d in data["data"]] # type: ignore

if len(raw_doc_list) > 0:
data_model, set_fields, _ = self._generate_returned_model(
raw_doc_list[0]
)
data_model, set_fields, _ = self._generate_returned_model(raw_doc_list[0])

data["data"] = [
data_model(
**{
field: value
for field, value in raw_doc.dict().items()
if field in set_fields
}
)
data_model(**{field: value for field, value in raw_doc.dict().items() if field in set_fields})
for raw_doc in raw_doc_list
]

@@ -754,9 +715,7 @@ def _submit_request_and_process(
message = data
else:
try:
message = ", ".join(
f"{entry['loc'][1]} - {entry['msg']}" for entry in data
)
message = ", ".join(f"{entry['loc'][1]} - {entry['msg']}" for entry in data)
except (KeyError, IndexError):
message = str(data)

@@ -767,9 +726,7 @@

def _generate_returned_model(self, doc):

set_fields = [
field for field, _ in doc if field in doc.dict(exclude_unset=True)
]
set_fields = [field for field, _ in doc if field in doc.dict(exclude_unset=True)]
unset_fields = [field for field in doc.__fields__ if field not in set_fields]

data_model = create_model(
@@ -779,19 +736,12 @@ def _generate_returned_model(self, doc):
)

data_model.__fields__ = {
**{
name: description
for name, description in data_model.__fields__.items()
if name in set_fields
},
**{name: description for name, description in data_model.__fields__.items() if name in set_fields},
"fields_not_requested": data_model.__fields__["fields_not_requested"],
}

def new_repr(self) -> str:
extra = ",\n".join(
f"\033[1m{n}\033[0;0m={getattr(self, n)!r}"
for n in data_model.__fields__
)
extra = ",\n".join(f"\033[1m{n}\033[0;0m={getattr(self, n)!r}" for n in data_model.__fields__)

s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501
return s
@@ -813,9 +763,7 @@ def new_getattr(self, attr) -> str:
" A full list of unrequested fields can be found in `fields_not_requested`."
)
else:
raise AttributeError(
f"{self.__class__.__name__!r} object has no attribute {attr!r}"
)
raise AttributeError(f"{self.__class__.__name__!r} object has no attribute {attr!r}")

data_model.__repr__ = new_repr
data_model.__str__ = new_str
@@ -872,10 +820,7 @@ def get_data_by_id(
"""

if document_id is None:
raise ValueError(
"Please supply a specific ID. You can use the query method to find "
"ids of interest."
)
raise ValueError("Please supply a specific ID. You can use the query method to find " "ids of interest.")

if self.primary_key in ["material_id", "task_id"]:
validate_ids([document_id])
@@ -897,28 +842,31 @@ def get_data_by_id(
if self.primary_key == "material_id":
# see if the material_id has changed, perhaps a task_id was supplied
# this should likely be re-thought
from mp_api.client import MPRester
from mp_api.client.routes.materials import MaterialsRester

with MPRester(api_key=self.api_key, endpoint=self.base_endpoint) as mpr:
new_document_id = mpr.get_materials_id_from_task_id(document_id)
with MaterialsRester(
api_key=self.api_key, endpoint=self.base_endpoint, use_document_model=False, monty_decode=False
) as mpr:
docs = mpr.search(task_ids=[document_id], fields=["material_id"])

if new_document_id is not None:
warnings.warn(
f"Document primary key has changed from {document_id} to {new_document_id}, "
f"returning data for {new_document_id} in {self.suffix} route. "
)
document_id = new_document_id
if len(docs) > 0:

results = self._query_resource_data(
criteria=criteria, fields=fields, suburl=document_id # type: ignore
)
new_document_id = docs[0].get("material_id", None)

if new_document_id is not None:
warnings.warn(
f"Document primary key has changed from {document_id} to {new_document_id}, "
f"returning data for {new_document_id} in {self.suffix} route. "
)

results = self._query_resource_data(
criteria=criteria, fields=fields, suburl=new_document_id # type: ignore
)

if not results:
raise MPRestError(f"No result for record {document_id}.")
elif len(results) > 1: # pragma: no cover
raise ValueError(
f"Multiple records for {document_id}, this shouldn't happen. Please report as a bug."
)
raise ValueError(f"Multiple records for {document_id}, this shouldn't happen. Please report as a bug.")
else:
return results[0]

Expand Down Expand Up @@ -1025,9 +973,7 @@ def count(self, criteria: Optional[Dict] = None) -> Union[int, str]:
False,
False,
) # do not waste cycles decoding
results = self._query_resource(
criteria=criteria, num_chunks=1, chunk_size=1
)
results = self._query_resource(criteria=criteria, num_chunks=1, chunk_size=1)
self.monty_decode, self.use_document_model = user_preferences
return results["meta"]["total_doc"]
except Exception: # pragma: no cover
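The "BaseRester nesting fix" half of the title is in `get_data_by_id` above: instead of importing the full `MPRester` (which instantiates every sub-rester just to map a task_id to its material_id), the lookup now goes through `MaterialsRester` directly, with the document model and Monty decoding disabled. A minimal sketch of that lookup, using the same names as the diff; the standalone helper and its signature are illustrative, and the warning and error handling of the real method are omitted:

```python
from mp_api.client.routes.materials import MaterialsRester


def resolve_material_id(api_key: str, base_endpoint: str, document_id: str):
    """Illustrative stand-in for the lookup now performed inside BaseRester.get_data_by_id."""
    with MaterialsRester(
        api_key=api_key,
        endpoint=base_endpoint,
        use_document_model=False,  # plain dicts are enough for a key lookup
        monty_decode=False,        # skip Monty decoding to keep the call lightweight
    ) as mpr:
        docs = mpr.search(task_ids=[document_id], fields=["material_id"])

    # The first hit, if any, carries the current material_id for the supplied task_id.
    return docs[0].get("material_id") if docs else None
```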
9 changes: 7 additions & 2 deletions mp_api/client/mprester.py
@@ -85,6 +85,7 @@ def __init__(
include_user_agent=True,
monty_decode: bool = True,
use_document_model: bool = True,
headers: dict = None,
):
"""
Args:
@@ -115,6 +116,7 @@ def __init__(
use_document_model: If False, skip the creating the document model and return data
as a dictionary. This can be simpler to work with but bypasses data validation
and will not give auto-complete for available fields.
headers (dict): Custom headers for localhost connections.
"""

if api_key and len(api_key) == 16:
@@ -126,14 +128,17 @@ def __init__(

self.api_key = api_key
self.endpoint = endpoint
self.session = BaseRester._create_session(api_key=api_key, include_user_agent=include_user_agent)
self.headers = headers or {}
self.session = BaseRester._create_session(
api_key=api_key, include_user_agent=include_user_agent, headers=self.headers
)
self.use_document_model = use_document_model
self.monty_decode = monty_decode

try:
from mpcontribs.client import Client

self.contribs = Client(api_key)
self.contribs = Client(api_key, headers=self.headers)
except ImportError:
self.contribs = None
warnings.warn(
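Both the `MPRester` constructor above and `BaseRester.session` now forward the custom headers into `_create_session`. A condensed sketch of the resulting merge order, assuming the same imports `client.py` already uses; the retry and adapter configuration of the real method is omitted:

```python
import platform
import sys

import requests
from pymatgen.core import __version__ as pmg_version


def create_session_sketch(api_key: str, include_user_agent: bool, headers: dict) -> requests.Session:
    """Condensed, illustrative version of BaseRester._create_session after this commit."""
    session = requests.Session()
    session.headers = {"x-api-key": api_key}  # API key header is set first
    session.headers.update(headers)           # custom headers are merged on top

    if include_user_agent:
        pymatgen_info = "pymatgen/" + pmg_version
        python_info = f"Python/{sys.version.split()[0]}"
        platform_info = f"{platform.system()}/{platform.release()}"
        session.headers["user-agent"] = f"{pymatgen_info} ({python_info} {platform_info})"

    return session
```

Because `update()` runs after the API key header is set, a custom `x-api-key` entry would take precedence, while a custom `user-agent` is overwritten whenever `include_user_agent` is True.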
