Skip to content

Commit

Permalink
feat: revert test function
Browse files Browse the repository at this point in the history
  • Loading branch information
christinestraub committed Jan 6, 2025
1 parent 648c24a commit a456d4e
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 46 deletions.
6 changes: 3 additions & 3 deletions test_unstructured/chunking/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,17 +152,17 @@ class Describe_chunk_elements:
],
)
def it_supports_the_include_orig_elements_option(
self, kwargs: dict[str, Any], expected_value: bool, chunk_elements_: Mock
self, kwargs: dict[str, Any], expected_value: bool, _chunk_elements_: Mock
):
# -- this line would raise if "include_orig_elements" was not an available parameter on
# -- `chunk_elements()`.
chunk_elements([], **kwargs)

_, opts = chunk_elements_.call_args.args
_, opts = _chunk_elements_.call_args.args
assert opts.include_orig_elements is expected_value

# -- fixtures --------------------------------------------------------------------------------

@pytest.fixture()
def chunk_elements_(self, request: FixtureRequest):
def _chunk_elements_(self, request: FixtureRequest):
return function_mock(request, "unstructured.chunking.basic._chunk_elements")
6 changes: 3 additions & 3 deletions test_unstructured/chunking/test_title.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,19 +456,19 @@ class Describe_chunk_by_title:
],
)
def it_supports_the_include_orig_elements_option(
self, kwargs: dict[str, Any], expected_value: bool, chunk_by_title_: Mock
self, kwargs: dict[str, Any], expected_value: bool, _chunk_by_title_: Mock
):
# -- this line would raise if "include_orig_elements" was not an available parameter on
# -- `chunk_by_title()`.
chunk_by_title([], **kwargs)

_, opts = chunk_by_title_.call_args.args
_, opts = _chunk_by_title_.call_args.args
assert opts.include_orig_elements is expected_value

# -- fixtures --------------------------------------------------------------------------------

@pytest.fixture()
def chunk_by_title_(self, request: FixtureRequest):
def _chunk_by_title_(self, request: FixtureRequest):
return function_mock(request, "unstructured.chunking.title._chunk_by_title")


Expand Down
21 changes: 17 additions & 4 deletions test_unstructured/nlp/test_tokenize.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
from typing import List, Tuple
from unittest.mock import patch

import nltk

from test_unstructured.nlp.mock_nltk import mock_sent_tokenize, mock_word_tokenize
from unstructured.nlp import tokenize


def test_nltk_assets_validation():
with patch("unstructured.nlp.tokenize.validate_nltk_assets") as mock_validate:
tokenize.validate_nltk_assets()
mock_validate.assert_called_once()
def test_nltk_packages_download_if_not_present():
tokenize._download_nltk_packages_if_not_present.cache_clear()
with patch.object(nltk, "find", side_effect=LookupError):
with patch.object(tokenize, "download_nltk_packages") as mock_download:
tokenize._download_nltk_packages_if_not_present()

mock_download.assert_called_once()


def test_nltk_packages_do_not_download_if():
tokenize._download_nltk_packages_if_not_present.cache_clear()
with patch.object(nltk, "find"), patch.object(nltk, "download") as mock_download:
tokenize._download_nltk_packages_if_not_present()

mock_download.assert_not_called()


def mock_pos_tag(tokens: List[str]) -> List[Tuple[str, str]]:
Expand Down
6 changes: 3 additions & 3 deletions test_unstructured/partition/test_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,10 @@ def and_it_uses_the_last_modified_date_from_the_source_file_when_the_message_has
opts_args: dict[str, Any],
filesystem_last_modified: str | None,
Message_sent_date_: Mock,
last_modified_prop_: Mock,
_last_modified_prop_: Mock,
):
Message_sent_date_.return_value = None
last_modified_prop_.return_value = filesystem_last_modified
_last_modified_prop_.return_value = filesystem_last_modified
opts_args["file_path"] = example_doc_path("fake-email.msg")
opts = MsgPartitionerOptions(**opts_args)

Expand Down Expand Up @@ -443,7 +443,7 @@ def it_provides_access_to_pass_through_kwargs_collected_by_the_partitioner_funct
# -- fixtures --------------------------------------------------------------------------------

@pytest.fixture
def last_modified_prop_(self, request: FixtureRequest):
def _last_modified_prop_(self, request: FixtureRequest):
return property_mock(request, MsgPartitionerOptions, "_last_modified")

@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,29 +30,29 @@ class DescribeOCRAgent:
"""Unit-test suite for `unstructured.partition.utils...ocr_interface.OCRAgent` class."""

def it_provides_access_to_the_configured_OCR_agent(
self, get_ocr_agent_cls_qname_: Mock, get_instance_: Mock, ocr_agent_: Mock
self, _get_ocr_agent_cls_qname_: Mock, get_instance_: Mock, ocr_agent_: Mock
):
get_ocr_agent_cls_qname_.return_value = OCR_AGENT_TESSERACT
_get_ocr_agent_cls_qname_.return_value = OCR_AGENT_TESSERACT
get_instance_.return_value = ocr_agent_

ocr_agent = OCRAgent.get_agent(language="eng")

get_ocr_agent_cls_qname_.assert_called_once_with()
_get_ocr_agent_cls_qname_.assert_called_once_with()
get_instance_.assert_called_once_with(OCR_AGENT_TESSERACT, "eng")
assert ocr_agent is ocr_agent_

def but_it_raises_when_the_requested_agent_is_not_whitelisted(
self, get_ocr_agent_cls_qname_: Mock
self, _get_ocr_agent_cls_qname_: Mock
):
get_ocr_agent_cls_qname_.return_value = "Invalid.Ocr.Agent.Qname"
_get_ocr_agent_cls_qname_.return_value = "Invalid.Ocr.Agent.Qname"
with pytest.raises(ValueError, match="must be set to a whitelisted module"):
OCRAgent.get_agent(language="eng")

@pytest.mark.parametrize("exception_cls", [ImportError, AttributeError])
def and_it_raises_when_the_requested_agent_cannot_be_loaded(
self, get_ocr_agent_cls_qname_: Mock, exception_cls: type[Exception], _clear_cache
self, _get_ocr_agent_cls_qname_: Mock, exception_cls: type[Exception], _clear_cache
):
get_ocr_agent_cls_qname_.return_value = OCR_AGENT_TESSERACT
_get_ocr_agent_cls_qname_.return_value = OCR_AGENT_TESSERACT
with patch(
"unstructured.partition.utils.ocr_models.ocr_interface.importlib.import_module",
side_effect=exception_cls,
Expand Down Expand Up @@ -99,7 +99,7 @@ def get_instance_(self, request: FixtureRequest):
return method_mock(request, OCRAgent, "get_instance")

@pytest.fixture()
def get_ocr_agent_cls_qname_(self, request: FixtureRequest):
def _get_ocr_agent_cls_qname_(self, request: FixtureRequest):
return method_mock(request, OCRAgent, "_get_ocr_agent_cls_qname")

@pytest.fixture()
Expand Down
27 changes: 2 additions & 25 deletions unstructured/nlp/tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,8 @@

CACHE_MAX_SIZE: Final[int] = 128


def is_ci_environment() -> bool:
"""
Checks if the current environment is a Continuous Integration (CI) environment.
Many CI systems set specific environment variables to indicate they are running in CI mode.
"""
# Common CI environment variables
ci_env_vars = [
"CI", # General CI indicator (e.g., GitHub Actions, GitLab, Travis CI)
"GITHUB_ACTIONS", # GitHub Actions
"GITLAB_CI", # GitLab CI/CD
"CIRCLECI", # CircleCI
"TRAVIS", # Travis CI
"JENKINS_HOME", # Jenkins
"BITBUCKET_BUILD_NUMBER", # Bitbucket Pipelines
]

# Check if any of the CI environment variables are set
return any(var in os.environ for var in ci_env_vars)


if not is_ci_environment():
# Define the NLTK data path based on the Docker image environment
NLTK_DATA_PATH = os.getenv("NLTK_DATA", "/home/notebook-user/nltk_data")
nltk.data.path.append(NLTK_DATA_PATH)
NLTK_DATA_PATH = os.getenv("NLTK_DATA", "/home/notebook-user/nltk_data")
nltk.data.path.append(NLTK_DATA_PATH)


def download_nltk_packages():
Expand Down

0 comments on commit a456d4e

Please sign in to comment.