Skip to content

Commit

Permalink
Support overriding just the validator class
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed Apr 24, 2023
1 parent 1bef5f5 commit 5791b62
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 39 deletions.
48 changes: 14 additions & 34 deletions singer_sdk/sinks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import IO, TYPE_CHECKING, Any, Mapping, Sequence

from dateutil import parser
from jsonschema import Draft7Validator, FormatChecker, Validator
from jsonschema import Draft7Validator, Validator

from singer_sdk.helpers._batch import (
BaseBatchFileEncoding,
Expand Down Expand Up @@ -91,54 +91,34 @@ def __init__(
def get_record_validator(self, schema: dict) -> Validator:
"""Get JSON schema validator for a given schema.
Override this method to customize the JSON schema validator.
Override this method to customize the JSON schema validator or to use a
different string format checker.
Args:
schema: JSON schema to validate records against.
Returns:
A ``jsonschema`` `validator`_.
A ``jsonschema`` `validator`_ instance.
.. _validator:
https://python-jsonschema.readthedocs.io/en/stable/api/jsonschema/validators/
"""
return Draft7Validator( # type: ignore[return-value]
schema,
format_checker=self.get_format_checker(),
)

def get_format_checker(self) -> FormatChecker:
"""Get format checker for JSON schema.
Override this method to add custom string format checkers to the JSON schema
validator.
This is useful when, for example, the target requires a specific format for a
date or datetime field.
cls = self.get_record_validator_class()
return cls(schema, format_checker=cls.FORMAT_CHECKER)

Example:
.. code-block:: python
def get_record_validator_class(self) -> type[Validator]:
"""Get JSON schema validator class.
def get_format_checker(self):
format_checker = super().get_format_checker()
def is_simple_date(value):
try:
datetime.datetime.strptime(value, "%Y-%m-%d")
except ValueError:
return False
return True
format_checker.checks("date")(is_simple_date)
return format_checker
Defaults to JSON schema Draft 7. Override this method to customize the JSON
schema validator.
Returns:
An instance of `jsonschema.FormatChecker`_.
A ``jsonschema`` `validator`_ class.
.. _jsonschema.FormatChecker:
https://python-jsonschema.readthedocs.io/en/stable/validate/#jsonschema.FormatChecker
.. _validator:
https://python-jsonschema.readthedocs.io/en/stable/api/jsonschema/validators/
"""
return FormatChecker(formats=()) # Don't check any formats by default
return Draft7Validator # type: ignore[return-value]

def _get_context(self, record: dict) -> dict: # noqa: ARG002
"""Return an empty dictionary by default.
Expand Down
Empty file added tests/core/sinks/__init__.py
Empty file.
11 changes: 6 additions & 5 deletions tests/core/sinks/test_format_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from singer_sdk.target_base import Target

ISOFORMAT = "%Y-%m-%dT%H:%M:%S.%f%z"
UTC = datetime.timezone.utc


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -42,18 +43,18 @@ def process_record(self, record: dict, context: dict) -> None:
@pytest.fixture(scope="module")
def default_checker(default_sink: Sink) -> FormatChecker:
"""Return a default format checker."""
return default_sink.get_format_checker()
return default_sink.get_record_validator_class().FORMAT_CHECKER


@pytest.fixture(scope="module")
def datetime_checker(default_sink: Sink) -> FormatChecker:
"""Return a custom 'date-time' format checker."""
checker = default_sink.get_format_checker()
checker = default_sink.get_record_validator_class().FORMAT_CHECKER

@checker.checks("date-time", raises=ValueError)
def check_time(instance: object) -> bool:
try:
datetime.datetime.strptime(instance, ISOFORMAT)
datetime.datetime.strptime(instance, ISOFORMAT).replace(tzinfo=UTC)
except ValueError:
return False
return True
Expand All @@ -62,7 +63,7 @@ def check_time(instance: object) -> bool:


@pytest.mark.parametrize(
("fmt", "value"),
"fmt,value",
[
pytest.param(
"any string",
Expand All @@ -84,7 +85,7 @@ def test_default_checks(default_checker: FormatChecker, value: str, fmt: str):


@pytest.mark.parametrize(
("fmt", "value"),
"fmt,value",
[
pytest.param(
"date-time",
Expand Down

0 comments on commit 5791b62

Please sign in to comment.