From 2a64cf612af63cec99efacf5aa3942bc685fdb46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez=20Mondrag=C3=B3n?= <16805946+edgarrmondragon@users.noreply.github.com> Date: Fri, 17 Nov 2023 16:11:20 -0600 Subject: [PATCH] feat: Better error messages when config validation fails (#768) --- samples/sample_tap_sqlite/__init__.py | 1 + samples/sample_target_sqlite/__init__.py | 1 + singer_sdk/exceptions.py | 15 +++ singer_sdk/plugin_base.py | 72 ++++++++----- tests/core/conftest.py | 101 ++++++++++++++++++ tests/core/test_mapper_class.py | 54 ++++++++++ tests/core/test_streams.py | 127 ++++------------------- tests/core/test_tap_class.py | 92 ++++++++++++++++ tests/core/test_target_class.py | 54 ++++++++++ tests/samples/test_target_sqlite.py | 2 +- 10 files changed, 387 insertions(+), 132 deletions(-) create mode 100644 tests/core/conftest.py create mode 100644 tests/core/test_mapper_class.py create mode 100644 tests/core/test_tap_class.py create mode 100644 tests/core/test_target_class.py diff --git a/samples/sample_tap_sqlite/__init__.py b/samples/sample_tap_sqlite/__init__.py index 49b4365f0..3aed5d21d 100644 --- a/samples/sample_tap_sqlite/__init__.py +++ b/samples/sample_tap_sqlite/__init__.py @@ -48,6 +48,7 @@ class SQLiteTap(SQLTap): DB_PATH_CONFIG, th.StringType, description="The path to your SQLite database file(s).", + required=True, examples=["./path/to/my.db", "/absolute/path/to/my.db"], ), ).to_dict() diff --git a/samples/sample_target_sqlite/__init__.py b/samples/sample_target_sqlite/__init__.py index 40384facf..8e43a5e87 100644 --- a/samples/sample_target_sqlite/__init__.py +++ b/samples/sample_target_sqlite/__init__.py @@ -52,6 +52,7 @@ class SQLiteTarget(SQLTarget): DB_PATH_CONFIG, th.StringType, description="The path to your SQLite database file(s).", + required=True, ), ).to_dict() diff --git a/singer_sdk/exceptions.py b/singer_sdk/exceptions.py index 351776291..75135e800 100644 --- a/singer_sdk/exceptions.py +++ b/singer_sdk/exceptions.py @@ -12,6 +12,21 @@ class ConfigValidationError(Exception): """Raised when a user's config settings fail validation.""" + def __init__( + self, + message: str, + *, + errors: list[str] | None = None, + ) -> None: + """Initialize a ConfigValidationError. + + Args: + message: A message describing the error. + errors: A list of errors which caused the validation error. + """ + super().__init__(message) + self.errors = errors or [] + class FatalAPIError(Exception): """Exception raised when a failed request should not be considered retriable.""" diff --git a/singer_sdk/plugin_base.py b/singer_sdk/plugin_base.py index 53e2cd2f2..b4e82296b 100644 --- a/singer_sdk/plugin_base.py +++ b/singer_sdk/plugin_base.py @@ -72,6 +72,43 @@ def __init__(self) -> None: super().__init__("Mapper not initialized. Please call setup_mapper() first.") +class SingerCommand(click.Command): + """Custom click command class for Singer packages.""" + + def __init__( + self, + *args: t.Any, + logger: logging.Logger, + **kwargs: t.Any, + ) -> None: + """Initialize the command. + + Args: + *args: Positional `click.Command` arguments. + logger: A logger instance. + **kwargs: Keyword `click.Command` arguments. + """ + super().__init__(*args, **kwargs) + self.logger = logger + + def invoke(self, ctx: click.Context) -> t.Any: # noqa: ANN401 + """Invoke the command, capturing warnings and logging them. + + Args: + ctx: The `click` context. + + Returns: + The result of the command invocation. + """ + logging.captureWarnings(capture=True) + try: + return super().invoke(ctx) + except ConfigValidationError as exc: + for error in exc.errors: + self.logger.error("Config validation error: %s", error) + sys.exit(1) + + class PluginBase(metaclass=abc.ABCMeta): """Abstract base class for taps.""" @@ -150,12 +187,12 @@ def __init__( if self._is_secret_config(k): config_dict[k] = SecretString(v) self._config = config_dict - self._validate_config(raise_errors=validate_config) - self._mapper: PluginMapper | None = None - metrics._setup_logging(self.config) self.metrics_logger = metrics.get_metrics_logger() + self._validate_config(raise_errors=validate_config) + self._mapper: PluginMapper | None = None + # Initialization timestamp self.__initialized_at = int(time.time() * 1000) @@ -351,27 +388,19 @@ def _is_secret_config(config_key: str) -> bool: """ return is_common_secret_key(config_key) - def _validate_config( - self, - *, - raise_errors: bool = True, - warnings_as_errors: bool = False, - ) -> tuple[list[str], list[str]]: + def _validate_config(self, *, raise_errors: bool = True) -> list[str]: """Validate configuration input against the plugin configuration JSON schema. Args: raise_errors: Flag to throw an exception if any validation errors are found. - warnings_as_errors: Flag to throw an exception if any warnings were emitted. Returns: - A tuple of configuration validation warnings and errors. + A list of validation errors. Raises: ConfigValidationError: If raise_errors is True and validation fails. """ - warnings: list[str] = [] errors: list[str] = [] - log_fn = self.logger.info config_jsonschema = self.config_jsonschema if config_jsonschema: @@ -389,19 +418,11 @@ def _validate_config( f"JSONSchema was: {config_jsonschema}" ) if raise_errors: - raise ConfigValidationError(summary) + raise ConfigValidationError(summary, errors=errors) - log_fn = self.logger.warning - else: - summary = f"Config validation passed with {len(warnings)} warnings." - for warning in warnings: - summary += f"\n{warning}" + self.logger.warning(summary) - if warnings_as_errors and raise_errors and warnings: - msg = f"One or more warnings ocurred during validation: {warnings}" - raise ConfigValidationError(msg) - log_fn(summary) - return warnings, errors + return errors @classmethod def print_version( @@ -555,7 +576,7 @@ def get_singer_command(cls: type[PluginBase]) -> click.Command: Returns: A callable CLI object. """ - return click.Command( + return SingerCommand( name=cls.name, callback=cls.invoke, context_settings={"help_option_names": ["--help"]}, @@ -596,6 +617,7 @@ def get_singer_command(cls: type[PluginBase]) -> click.Command: is_eager=True, ), ], + logger=cls.logger, ) @plugin_cli diff --git a/tests/core/conftest.py b/tests/core/conftest.py new file mode 100644 index 000000000..06355ccfe --- /dev/null +++ b/tests/core/conftest.py @@ -0,0 +1,101 @@ +"""Tap, target and stream test fixtures.""" + +from __future__ import annotations + +import typing as t + +import pendulum +import pytest + +from singer_sdk import Stream, Tap +from singer_sdk.typing import ( + DateTimeType, + IntegerType, + PropertiesList, + Property, + StringType, +) + + +class SimpleTestStream(Stream): + """Test stream class.""" + + name = "test" + schema = PropertiesList( + Property("id", IntegerType, required=True), + Property("value", StringType, required=True), + Property("updatedAt", DateTimeType, required=True), + ).to_dict() + replication_key = "updatedAt" + + def __init__(self, tap: Tap): + """Create a new stream.""" + super().__init__(tap, schema=self.schema, name=self.name) + + def get_records( + self, + context: dict | None, # noqa: ARG002 + ) -> t.Iterable[dict[str, t.Any]]: + """Generate records.""" + yield {"id": 1, "value": "Egypt"} + yield {"id": 2, "value": "Germany"} + yield {"id": 3, "value": "India"} + + +class UnixTimestampIncrementalStream(SimpleTestStream): + name = "unix_ts" + schema = PropertiesList( + Property("id", IntegerType, required=True), + Property("value", StringType, required=True), + Property("updatedAt", IntegerType, required=True), + ).to_dict() + replication_key = "updatedAt" + + +class UnixTimestampIncrementalStream2(UnixTimestampIncrementalStream): + name = "unix_ts_override" + + def compare_start_date(self, value: str, start_date_value: str) -> str: + """Compare a value to a start date value.""" + + start_timestamp = pendulum.parse(start_date_value).format("X") + return max(value, start_timestamp, key=float) + + +class SimpleTestTap(Tap): + """Test tap class.""" + + name = "test-tap" + config_jsonschema = PropertiesList( + Property("username", StringType, required=True), + Property("password", StringType, required=True), + Property("start_date", DateTimeType), + additional_properties=False, + ).to_dict() + + def discover_streams(self) -> list[Stream]: + """List all streams.""" + return [ + SimpleTestStream(self), + UnixTimestampIncrementalStream(self), + UnixTimestampIncrementalStream2(self), + ] + + +@pytest.fixture +def tap_class(): + """Return the tap class.""" + return SimpleTestTap + + +@pytest.fixture +def tap() -> SimpleTestTap: + """Tap instance.""" + return SimpleTestTap( + config={ + "username": "utest", + "password": "ptest", + "start_date": "2021-01-01", + }, + parse_env_config=False, + ) diff --git a/tests/core/test_mapper_class.py b/tests/core/test_mapper_class.py new file mode 100644 index 000000000..0f0c1192a --- /dev/null +++ b/tests/core/test_mapper_class.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import json +from contextlib import nullcontext + +import pytest +from click.testing import CliRunner + +from samples.sample_mapper.mapper import StreamTransform +from singer_sdk.exceptions import ConfigValidationError + + +@pytest.mark.parametrize( + "config_dict,expectation,errors", + [ + pytest.param( + {}, + pytest.raises(ConfigValidationError, match="Config validation failed"), + ["'stream_maps' is a required property"], + id="missing_stream_maps", + ), + pytest.param( + {"stream_maps": {}}, + nullcontext(), + [], + id="valid_config", + ), + ], +) +def test_config_errors(config_dict: dict, expectation, errors: list[str]): + with expectation as exc: + StreamTransform(config=config_dict, validate_config=True) + + if isinstance(exc, pytest.ExceptionInfo): + assert exc.value.errors == errors + + +def test_cli_help(): + """Test the CLI help message.""" + runner = CliRunner(mix_stderr=False) + result = runner.invoke(StreamTransform.cli, ["--help"]) + assert result.exit_code == 0 + assert "Show this message and exit." in result.output + + +def test_cli_config_validation(tmp_path): + """Test the CLI config validation.""" + runner = CliRunner(mix_stderr=False) + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps({})) + result = runner.invoke(StreamTransform.cli, ["--config", str(config_path)]) + assert result.exit_code == 1 + assert not result.stdout + assert "'stream_maps' is a required property" in result.stderr diff --git a/tests/core/test_streams.py b/tests/core/test_streams.py index a3a451086..8a415e55d 100644 --- a/tests/core/test_streams.py +++ b/tests/core/test_streams.py @@ -16,68 +16,17 @@ from singer_sdk.helpers._classproperty import classproperty from singer_sdk.helpers.jsonpath import _compile_jsonpath, extract_jsonpath from singer_sdk.pagination import first -from singer_sdk.streams.core import ( - REPLICATION_FULL_TABLE, - REPLICATION_INCREMENTAL, - Stream, -) +from singer_sdk.streams.core import REPLICATION_FULL_TABLE, REPLICATION_INCREMENTAL from singer_sdk.streams.graphql import GraphQLStream from singer_sdk.streams.rest import RESTStream -from singer_sdk.tap_base import Tap -from singer_sdk.typing import ( - DateTimeType, - IntegerType, - PropertiesList, - Property, - StringType, -) +from singer_sdk.typing import IntegerType, PropertiesList, Property, StringType +from tests.core.conftest import SimpleTestStream CONFIG_START_DATE = "2021-01-01" - -class SimpleTestStream(Stream): - """Test stream class.""" - - name = "test" - schema = PropertiesList( - Property("id", IntegerType, required=True), - Property("value", StringType, required=True), - Property("updatedAt", DateTimeType, required=True), - ).to_dict() - replication_key = "updatedAt" - - def __init__(self, tap: Tap): - """Create a new stream.""" - super().__init__(tap, schema=self.schema, name=self.name) - - def get_records( - self, - context: dict | None, # noqa: ARG002 - ) -> t.Iterable[dict[str, t.Any]]: - """Generate records.""" - yield {"id": 1, "value": "Egypt"} - yield {"id": 2, "value": "Germany"} - yield {"id": 3, "value": "India"} - - -class UnixTimestampIncrementalStream(SimpleTestStream): - name = "unix_ts" - schema = PropertiesList( - Property("id", IntegerType, required=True), - Property("value", StringType, required=True), - Property("updatedAt", IntegerType, required=True), - ).to_dict() - replication_key = "updatedAt" - - -class UnixTimestampIncrementalStream2(UnixTimestampIncrementalStream): - name = "unix_ts_override" - - def compare_start_date(self, value: str, start_date_value: str) -> str: - """Compare a value to a start date value.""" - - start_timestamp = pendulum.parse(start_date_value).format("X") - return max(value, start_timestamp, key=float) +if t.TYPE_CHECKING: + from singer_sdk import Stream, Tap + from tests.core.conftest import SimpleTestTap class RestTestStream(RESTStream): @@ -124,43 +73,13 @@ class GraphqlTestStream(GraphQLStream): replication_key = "updatedAt" -class SimpleTestTap(Tap): - """Test tap class.""" - - name = "test-tap" - settings_jsonschema = PropertiesList(Property("start_date", DateTimeType)).to_dict() - - def discover_streams(self) -> list[Stream]: - """List all streams.""" - return [ - SimpleTestStream(self), - UnixTimestampIncrementalStream(self), - UnixTimestampIncrementalStream2(self), - ] - - @pytest.fixture -def tap() -> SimpleTestTap: - """Tap instance.""" - return SimpleTestTap( - config={"start_date": CONFIG_START_DATE}, - parse_env_config=False, - ) - - -@pytest.fixture -def stream(tap: SimpleTestTap) -> SimpleTestStream: - """Create a new stream instance.""" - return t.cast(SimpleTestStream, tap.load_streams()[0]) - - -@pytest.fixture -def unix_timestamp_stream(tap: SimpleTestTap) -> UnixTimestampIncrementalStream: +def stream(tap): """Create a new stream instance.""" - return t.cast(UnixTimestampIncrementalStream, tap.load_streams()[1]) + return tap.load_streams()[0] -def test_stream_apply_catalog(stream: SimpleTestStream): +def test_stream_apply_catalog(stream: Stream): """Applying a catalog to a stream should overwrite fields.""" assert stream.primary_keys == [] assert stream.replication_key == "updatedAt" @@ -251,7 +170,7 @@ def test_stream_apply_catalog(stream: SimpleTestStream): ], ) def test_stream_starting_timestamp( - tap: SimpleTestTap, + tap: Tap, stream_name: str, bookmark_value: str, expected_starting_value: t.Any, @@ -353,12 +272,7 @@ class InvalidReplicationKeyStream(SimpleTestStream): "nested_values", ], ) -def test_jsonpath_rest_stream( - tap: SimpleTestTap, - path: str, - content: str, - result: list[dict], -): +def test_jsonpath_rest_stream(tap: Tap, path: str, content: str, result: list[dict]): """Validate records are extracted correctly from the API response.""" fake_response = requests.Response() fake_response._content = str.encode(content) @@ -371,7 +285,7 @@ def test_jsonpath_rest_stream( assert list(records) == result -def test_jsonpath_graphql_stream_default(tap: SimpleTestTap): +def test_jsonpath_graphql_stream_default(tap: Tap): """Validate graphql JSONPath, defaults to the stream name.""" content = """{ "data": { @@ -391,7 +305,7 @@ def test_jsonpath_graphql_stream_default(tap: SimpleTestTap): assert list(records) == [{"id": 1, "value": "abc"}, {"id": 2, "value": "def"}] -def test_jsonpath_graphql_stream_override(tap: SimpleTestTap): +def test_jsonpath_graphql_stream_override(tap: Tap): """Validate graphql jsonpath can be updated.""" content = """[ {"id": 1, "value": "abc"}, @@ -478,7 +392,7 @@ def records_jsonpath(cls): # noqa: N805 ], ) def test_next_page_token_jsonpath( - tap: SimpleTestTap, + tap: Tap, path: str, content: str, headers: dict, @@ -510,7 +424,7 @@ def test_cached_jsonpath(): assert recompiled is compiled -def test_sync_costs_calculation(tap: SimpleTestTap, caplog): +def test_sync_costs_calculation(tap: Tap, caplog): """Test sync costs are added up correctly.""" fake_request = requests.PreparedRequest() fake_response = requests.Response() @@ -595,7 +509,7 @@ def calculate_test_cost( ), ], ) -def test_stream_class_selection(input_catalog, selection): +def test_stream_class_selection(tap_class, input_catalog, selection): """Test stream class selection.""" class SelectedStream(RESTStream): @@ -607,11 +521,12 @@ class UnselectedStream(SelectedStream): name = "unselected_stream" selected_by_default = False - class MyTap(SimpleTestTap): + class MyTap(tap_class): def discover_streams(self): return [SelectedStream(self), UnselectedStream(self)] # Check that the selected stream is selected - tap = MyTap(config=None, catalog=input_catalog) - for stream in selection: - assert tap.streams[stream].selected is selection[stream] + tap = MyTap(config=None, catalog=input_catalog, validate_config=False) + assert all( + tap.streams[stream].selected is selection[stream] for stream in selection + ) diff --git a/tests/core/test_tap_class.py b/tests/core/test_tap_class.py new file mode 100644 index 000000000..93015fbb1 --- /dev/null +++ b/tests/core/test_tap_class.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import json +import typing as t +from contextlib import nullcontext + +import pytest +from click.testing import CliRunner + +from singer_sdk.exceptions import ConfigValidationError + +if t.TYPE_CHECKING: + from singer_sdk import Tap + + +@pytest.mark.parametrize( + "config_dict,expectation,errors", + [ + pytest.param( + {}, + pytest.raises(ConfigValidationError, match="Config validation failed"), + ["'username' is a required property", "'password' is a required property"], + id="missing_username_and_password", + ), + pytest.param( + {"username": "utest"}, + pytest.raises(ConfigValidationError, match="Config validation failed"), + ["'password' is a required property"], + id="missing_password", + ), + pytest.param( + {"username": "utest", "password": "ptest", "extra": "not valid"}, + pytest.raises(ConfigValidationError, match="Config validation failed"), + ["Additional properties are not allowed ('extra' was unexpected)"], + id="extra_property", + ), + pytest.param( + {"username": "utest", "password": "ptest"}, + nullcontext(), + [], + id="valid_config", + ), + ], +) +def test_config_errors( + tap_class: type[Tap], + config_dict: dict, + expectation, + errors: list[str], +): + with expectation as exc: + tap_class(config=config_dict, validate_config=True) + + if isinstance(exc, pytest.ExceptionInfo): + assert exc.value.errors == errors + + +def test_cli(tap_class: type[Tap]): + """Test the CLI.""" + runner = CliRunner(mix_stderr=False) + result = runner.invoke(tap_class.cli, ["--help"]) + assert result.exit_code == 0 + assert "Show this message and exit." in result.output + + +def test_cli_config_validation(tap_class: type[Tap], tmp_path): + """Test the CLI config validation.""" + runner = CliRunner(mix_stderr=False) + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps({})) + result = runner.invoke(tap_class.cli, ["--config", str(config_path)]) + assert result.exit_code == 1 + assert not result.stdout + assert "'username' is a required property" in result.stderr + assert "'password' is a required property" in result.stderr + + +def test_cli_discover(tap_class: type[Tap], tmp_path): + """Test the CLI discover command.""" + runner = CliRunner(mix_stderr=False) + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps({})) + result = runner.invoke( + tap_class.cli, + [ + "--config", + str(config_path), + "--discover", + ], + ) + assert result.exit_code == 0 + assert "streams" in json.loads(result.stdout) diff --git a/tests/core/test_target_class.py b/tests/core/test_target_class.py new file mode 100644 index 000000000..f84ae1dae --- /dev/null +++ b/tests/core/test_target_class.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import json +from contextlib import nullcontext + +import pytest +from click.testing import CliRunner + +from samples.sample_target_sqlite import SQLiteTarget +from singer_sdk.exceptions import ConfigValidationError + + +@pytest.mark.parametrize( + "config_dict,expectation,errors", + [ + pytest.param( + {}, + pytest.raises(ConfigValidationError, match="Config validation failed"), + ["'path_to_db' is a required property"], + id="missing_path_to_db", + ), + pytest.param( + {"path_to_db": "sqlite://test.db"}, + nullcontext(), + [], + id="valid_config", + ), + ], +) +def test_config_errors(config_dict: dict, expectation, errors: list[str]): + with expectation as exc: + SQLiteTarget(config=config_dict, validate_config=True) + + if isinstance(exc, pytest.ExceptionInfo): + assert exc.value.errors == errors + + +def test_cli(): + """Test the CLI.""" + runner = CliRunner(mix_stderr=False) + result = runner.invoke(SQLiteTarget.cli, ["--help"]) + assert result.exit_code == 0 + assert "Show this message and exit." in result.output + + +def test_cli_config_validation(tmp_path): + """Test the CLI config validation.""" + runner = CliRunner(mix_stderr=False) + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps({})) + result = runner.invoke(SQLiteTarget.cli, ["--config", str(config_path)]) + assert result.exit_code == 1 + assert not result.stdout + assert "'path_to_db' is a required property" in result.stderr diff --git a/tests/samples/test_target_sqlite.py b/tests/samples/test_target_sqlite.py index 727b760ba..a66805a09 100644 --- a/tests/samples/test_target_sqlite.py +++ b/tests/samples/test_target_sqlite.py @@ -36,7 +36,7 @@ def path_to_target_db(tmp_path: Path) -> Path: @pytest.fixture -def sqlite_target_test_config(path_to_target_db: str) -> dict: +def sqlite_target_test_config(path_to_target_db: Path) -> dict: """Get configuration dictionary for target-csv.""" return {"path_to_db": str(path_to_target_db)}