Skip to content

Commit

Permalink
fix: wrong label map key type (#322)
Browse files Browse the repository at this point in the history
When initialization params are read from the JSON file whose path is given by the
`UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH` environment
variable, the label_map keys are loaded as strings (JSON object keys are
always strings), so they have to be converted to `int` type before they can
be referenced. Added code to always convert label_map keys to `int` type.
  • Loading branch information
erjieyong authored Mar 18, 2024
1 parent fc3f38c commit 935f610
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

* bug: check for None in Chipper bounding box reduction
* chore: removes `install-detectron2` from the `Makefile`
* fix: convert label_map keys read from the JSON file specified by the `UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH` environment variable to int type

## 0.7.24

Expand Down
24 changes: 18 additions & 6 deletions test_unstructured_inference/models/test_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Any
from unittest import mock

Expand Down Expand Up @@ -45,7 +46,9 @@ def test_model_initializes_once():
from unstructured_inference.inference import layout

with mock.patch.dict(models.model_class_map, {"yolox": MockModel}), mock.patch.object(
models, "models", {}
models,
"models",
{},
):
doc = layout.DocumentLayout.from_file("sample-docs/loremipsum.pdf")
doc.pages[0].detection_model.initializer.assert_called_once()
Expand Down Expand Up @@ -143,23 +146,32 @@ def test_env_variables_override_default_model(monkeypatch):
# args, we should get back the model the env var calls for
monkeypatch.setattr(models, "models", {})
with mock.patch.dict(
models.os.environ, {"UNSTRUCTURED_DEFAULT_MODEL_NAME": "checkbox"}
models.os.environ,
{"UNSTRUCTURED_DEFAULT_MODEL_NAME": "checkbox"},
), mock.patch.dict(models.model_class_map, {"checkbox": MockModel}):
model = models.get_model()
assert isinstance(model, MockModel)


def test_env_variables_override_intialization_params(monkeypatch):
def test_env_variables_override_initialization_params(monkeypatch):
# When initialization params are specified in an environment variable, and we call get_model, we
# should see that the model was initialized with those params
monkeypatch.setattr(models, "models", {})
fake_label_map = {"1": "label1", "2": "label2"}
with mock.patch.dict(
models.os.environ,
{"UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH": "fake_json.json"},
), mock.patch.object(models, "DEFAULT_MODEL", "fake"), mock.patch.dict(
models.model_class_map, {"fake": mock.MagicMock()}
models.model_class_map,
{"fake": mock.MagicMock()},
), mock.patch(
"builtins.open", mock.mock_open(read_data='{"date": "3/26/81"}')
"builtins.open",
mock.mock_open(
read_data='{"model_path": "fakepath", "label_map": ' + json.dumps(fake_label_map) + "}",
),
):
model = models.get_model()
model.initialize.assert_called_once_with(date="3/26/81")
model.initialize.assert_called_once_with(
model_path="fakepath",
label_map={1: "label1", 2: "label2"},
)
24 changes: 9 additions & 15 deletions unstructured_inference/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,15 @@
from unstructured_inference.models.detectron2 import (
MODEL_TYPES as DETECTRON2_MODEL_TYPES,
)
from unstructured_inference.models.detectron2 import (
UnstructuredDetectronModel,
)
from unstructured_inference.models.detectron2 import UnstructuredDetectronModel
from unstructured_inference.models.detectron2onnx import (
MODEL_TYPES as DETECTRON2_ONNX_MODEL_TYPES,
)
from unstructured_inference.models.detectron2onnx import (
UnstructuredDetectronONNXModel,
)
from unstructured_inference.models.super_gradients import (
UnstructuredSuperGradients,
)
from unstructured_inference.models.detectron2onnx import UnstructuredDetectronONNXModel
from unstructured_inference.models.super_gradients import UnstructuredSuperGradients
from unstructured_inference.models.unstructuredmodel import UnstructuredModel
from unstructured_inference.models.yolox import (
MODEL_TYPES as YOLOX_MODEL_TYPES,
)
from unstructured_inference.models.yolox import (
UnstructuredYoloXModel,
)
from unstructured_inference.models.yolox import MODEL_TYPES as YOLOX_MODEL_TYPES
from unstructured_inference.models.yolox import UnstructuredYoloXModel

DEFAULT_MODEL = "yolox"

Expand Down Expand Up @@ -58,6 +48,10 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
if initialize_param_json is not None:
with open(initialize_param_json) as fp:
initialize_params = json.load(fp)
label_map_int_keys = {
int(key): value for key, value in initialize_params["label_map"].items()
}
initialize_params["label_map"] = label_map_int_keys
else:
if model_name in DETECTRON2_MODEL_TYPES:
initialize_params = DETECTRON2_MODEL_TYPES[model_name]
Expand Down

0 comments on commit 935f610

Please sign in to comment.