Skip to content

Commit

Permalink
Smart chunk with max URL length consideration (#695)
Browse files Browse the repository at this point in the history
* Smart chunk with max URL length

* Linting

* Skip mpcontribs tests

* Add large list test

* Add python 3.10 to test matrix

* Python versions to strings

* Fix min param chunk size

* Fix issue with empty chunk list

* Linting
  • Loading branch information
Jason Munro authored Oct 20, 2022
1 parent 72eafaa commit d3f7687
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 5 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ jobs:
test:
strategy:
max-parallel: 2
max-parallel: 3
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [3.8, 3.9]
python-version: ["3.8", "3.9", "3.10"]

runs-on: ${{ matrix.os }}

Expand Down
38 changes: 37 additions & 1 deletion mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from math import ceil
from os import environ
from typing import Dict, Generic, List, Optional, Tuple, TypeVar, Union
from urllib.parse import urljoin
from urllib.parse import urljoin, quote

import requests
from emmet.core.utils import jsanitize
Expand Down Expand Up @@ -342,11 +342,47 @@ def _submit_requests(
# trying to evenly divide num_chunks by the total number of new
# criteria dicts.
if parallel_param is not None:

# Determine slice size accounting for character maximum in HTTP URL
# First get URl length without parallel param
url_string = ""
for key, value in criteria.items():
if key != parallel_param:
parsed_val = quote(str(value))
url_string += f"{key}={parsed_val}&"

bare_url_len = len(url_string)
max_param_str_length = (
MAPIClientSettings().MAX_HTTP_URL_LENGTH - bare_url_len
)

# Next, check if default number of parallel requests works.
# If not, make slice size the minimum number of param entries
# contained in any substring of length max_param_str_length.
param_length = len(criteria[parallel_param].split(","))
slice_size = (
int(param_length / MAPIClientSettings().NUM_PARALLEL_REQUESTS) or 1
)

url_param_string = quote(criteria[parallel_param])

parallel_param_str_chunks = [
url_param_string[i : i + max_param_str_length]
for i in range(0, len(url_param_string), max_param_str_length)
if (i + max_param_str_length) <= len(url_param_string)
]

if len(parallel_param_str_chunks) > 0:

params_min_chunk = min(
parallel_param_str_chunks, key=lambda x: len(x.split("%2C"))
)

num_params_min_chunk = len(params_min_chunk.split("%2C"))

if num_params_min_chunk < slice_size:
slice_size = num_params_min_chunk or 1

new_param_values = [
entry
for entry in (
Expand Down
11 changes: 9 additions & 2 deletions mp_api/client/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,20 @@ class MAPIClientSettings(BaseSettings):
)

NUM_PARALLEL_REQUESTS: int = Field(
CPU_COUNT, description="Number of parallel requests to send.",
CPU_COUNT,
description="Number of parallel requests to send.",
)

MAX_RETRIES: int = Field(3, description="Maximum number of retries for requests.")

MUTE_PROGRESS_BARS: bool = Field(
False, description="Whether to mute progress bars when data is retrieved.",
False,
description="Whether to mute progress bars when data is retrieved.",
)

MAX_HTTP_URL_LENGTH: int = Field(
2000,
description="Number of characters to use to define the maximum length of a given HTTP URL.",
)

class Config:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def test_get_entries_in_chemsys(self, mpr):
for e in gibbs_entries:
assert isinstance(e, GibbsComputedStructureEntry)

@pytest.mark.skip(reason="SSL issues")
def test_get_pourbaix_entries(self, mpr):
# test input chemsys as a list of elements
pbx_entries = mpr.get_pourbaix_entries(["Fe", "Cr"])
Expand Down Expand Up @@ -245,6 +246,7 @@ def test_get_pourbaix_entries(self, mpr):
# so4_two_minus = pbx_entries[9]
# self.assertAlmostEqual(so4_two_minus.energy, 0.301511, places=3)

@pytest.mark.skip(reason="SSL issues")
def test_get_ion_entries(self, mpr):
entries = mpr.get_entries_in_chemsys("Ti-O-H")
pd = PhaseDiagram(entries)
Expand Down Expand Up @@ -312,3 +314,13 @@ def test_get_charge_density_data(self, mpr):
def test_get_wulff_shape(self, mpr):
ws = mpr.get_wulff_shape("mp-126")
assert isinstance(ws, WulffShape)

def test_large_list(self, mpr):
mpids = [
str(doc.material_id)
for doc in mpr.summary.search(
chunk_size=1000, num_chunks=15, fields=["material_id"]
)
]
docs = mpr.summary.search(material_ids=mpids, fields=["material_ids"])
assert len(docs) == 15000

0 comments on commit d3f7687

Please sign in to comment.