Skip to content

Commit

Permalink
Refactor/get model allow registering of new models with function calls (
Browse files Browse the repository at this point in the history
#333)

This PR allows us to register new models and use them in the system:
- refactors `get_model` so that it relies on the information stored in
`model_class_map` and `model_config_map` to initialize a model from
a given model name (which is the key in both mappings)
- adds a new function, `unstructured_inference.models.base.register_new_model`,
which adds a new model definition to the class mapping and the config
mapping
- after calling `register_new_model`, one can call `get_model` with the
new model name to get the new model

## testing

New unit tests should pass
  • Loading branch information
badGarnet authored Apr 4, 2024
1 parent 0a08377 commit 147c5b1
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 20 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
## 0.7.26-dev0
## 0.7.26

* feat: add a set of new `ElementType`s to extend future element types recognition
* feat: allow registering new models for inference using the `unstructured_inference.models.base.register_new_model` function

## 0.7.25

Expand Down
23 changes: 23 additions & 0 deletions test_unstructured_inference/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,35 @@ def predict(self, x: Any) -> Any:
return []


# Model definitions passed to `register_new_model` in the tests below: each key is a model
# name and each value is the configuration stored for that name.
MOCK_MODEL_TYPES = {"foo": {"input_shape": (640, 640)}}


def test_get_model(monkeypatch):
    """`get_model` returns an instance of the class mapped to the requested model name."""
    # Replace the module-level `models` dict with an empty one for the duration of the test
    # (presumably the store of already-loaded models -- confirm against `get_model`).
    monkeypatch.setattr(models, "models", {})
    patched_classes = {"checkbox": MockModel}
    # `patch.dict` restores `model_class_map` when the block exits.
    with mock.patch.dict(models.model_class_map, patched_classes):
        loaded = models.get_model("checkbox")
        assert isinstance(loaded, MockModel)


def test_register_new_model():
    """A model registered with `register_new_model` can be obtained through `get_model`,
    and resetting the mappings to their defaults removes it again."""
    assert "foo" not in models.model_class_map
    assert "foo" not in models.model_config_map
    try:
        models.register_new_model(MOCK_MODEL_TYPES, MockModel)
        assert "foo" in models.model_class_map
        assert "foo" in models.model_config_map
        model = models.get_model("foo")
        # check the type first so a wrong class fails here rather than on `.initializer`
        assert isinstance(model, MockModel)
        # exactly one recorded call, whose last item equals the registered config
        assert len(model.initializer.mock_calls) == 1
        assert model.initializer.mock_calls[0][-1] == MOCK_MODEL_TYPES["foo"]
    finally:
        # unregister the new model by resetting to the defaults; this runs in `finally` so
        # that a failed assertion above cannot leak "foo" into the tests that run afterwards
        models.model_class_map, models.model_config_map = models.get_default_model_mappings()
    assert "foo" not in models.model_class_map
    assert "foo" not in models.model_config_map


def test_raises_invalid_model():
with pytest.raises(models.UnknownModelException):
models.get_model("fake_model")
Expand Down
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.26-dev0" # pragma: no cover
__version__ = "0.7.26" # pragma: no cover
1 change: 1 addition & 0 deletions unstructured_inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
settings that should not be altered without making a code change (e.g., definition of 1Gb of memory
in bytes). Constants should go into `./constants.py`
"""

import os
from dataclasses import dataclass

Expand Down
8 changes: 5 additions & 3 deletions unstructured_inference/inference/layoutelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,11 @@ def merge_inferred_layout_with_extracted_layout(
categorized_extracted_elements_to_add = [
LayoutElement(
text=el.text,
type=ElementType.IMAGE
if isinstance(el, ImageTextRegion)
else ElementType.UNCATEGORIZED_TEXT,
type=(
ElementType.IMAGE
if isinstance(el, ImageTextRegion)
else ElementType.UNCATEGORIZED_TEXT
),
source=el.source,
bbox=el.bbox,
)
Expand Down
50 changes: 35 additions & 15 deletions unstructured_inference/models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import json
import os
from typing import Dict, Optional, Type
from typing import Dict, Optional, Tuple, Type

from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES
from unstructured_inference.models.chipper import UnstructuredChipperModel
Expand All @@ -15,17 +17,41 @@
from unstructured_inference.models.unstructuredmodel import UnstructuredModel
from unstructured_inference.models.yolox import MODEL_TYPES as YOLOX_MODEL_TYPES
from unstructured_inference.models.yolox import UnstructuredYoloXModel
from unstructured_inference.utils import LazyDict

DEFAULT_MODEL = "yolox"

models: Dict[str, UnstructuredModel] = {}

model_class_map: Dict[str, Type[UnstructuredModel]] = {
**{name: UnstructuredDetectronModel for name in DETECTRON2_MODEL_TYPES},
**{name: UnstructuredDetectronONNXModel for name in DETECTRON2_ONNX_MODEL_TYPES},
**{name: UnstructuredYoloXModel for name in YOLOX_MODEL_TYPES},
**{name: UnstructuredChipperModel for name in CHIPPER_MODEL_TYPES},
}

def get_default_model_mappings() -> (
    Tuple[
        Dict[str, Type[UnstructuredModel]],
        Dict[str, dict | LazyDict],
    ]
):
    """build the default mappings for the models that are in `unstructured_inference` repo

    Returns a `(class mapping, config mapping)` tuple; both are freshly built dicts keyed by
    model name.
    """
    # Listed in merge order: if a name appears in more than one `MODEL_TYPES`, the later
    # entry wins in both mappings.
    builtin_models = (
        (UnstructuredDetectronModel, DETECTRON2_MODEL_TYPES),
        (UnstructuredDetectronONNXModel, DETECTRON2_ONNX_MODEL_TYPES),
        (UnstructuredYoloXModel, YOLOX_MODEL_TYPES),
        (UnstructuredChipperModel, CHIPPER_MODEL_TYPES),
    )
    class_by_name = {}
    for model_class, model_types in builtin_models:
        for name in model_types:
            class_by_name[name] = model_class
    config_by_name = {}
    for _, model_types in builtin_models:
        config_by_name.update(model_types)
    return class_by_name, config_by_name


model_class_map, model_config_map = get_default_model_mappings()


def register_new_model(model_config: dict, model_class: Type[UnstructuredModel]) -> None:
    """register a new model by updating `model_config_map` and `model_class_map` with the new
    model class information

    Every key of `model_config` is used as a model name: its value is stored in
    `model_config_map`, and `model_class` is stored under the same name in `model_class_map`.
    Names that already exist in the mappings are overwritten.

    Args:
        model_config: mapping from model name to the configuration stored for that name.
        model_class: the model class itself (not an instance) to associate with every name
            in `model_config`.
    """
    model_config_map.update(model_config)
    model_class_map.update({name: model_class for name in model_config})


def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
Expand All @@ -51,14 +77,8 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
}
initialize_params["label_map"] = label_map_int_keys
else:
if model_name in DETECTRON2_MODEL_TYPES:
initialize_params = DETECTRON2_MODEL_TYPES[model_name]
elif model_name in DETECTRON2_ONNX_MODEL_TYPES:
initialize_params = DETECTRON2_ONNX_MODEL_TYPES[model_name]
elif model_name in YOLOX_MODEL_TYPES:
initialize_params = YOLOX_MODEL_TYPES[model_name]
elif model_name in CHIPPER_MODEL_TYPES:
initialize_params = CHIPPER_MODEL_TYPES[model_name]
if model_name in model_config_map:
initialize_params = model_config_map[model_name]
else:
raise UnknownModelException(f"Unknown model type: {model_name}")

Expand Down

0 comments on commit 147c5b1

Please sign in to comment.