Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: DBTP-1629 decoupling with duck typing #713

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions dbt_platform_helper/domain/config_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import boto3

import dbt_platform_helper.domain.versions as versions
from dbt_platform_helper.platform_exception import PlatformException
from dbt_platform_helper.providers.aws.opensearch import OpensearchProvider
from dbt_platform_helper.providers.aws.redis import RedisProvider
from dbt_platform_helper.providers.io import ClickIOProvider
from dbt_platform_helper.providers.opensearch import OpensearchProvider
from dbt_platform_helper.providers.redis import RedisProvider


class ConfigValidatorError(PlatformException):
Expand All @@ -32,7 +33,7 @@ def run_validations(self, config: dict):
validation(config)

def _validate_extension_supported_versions(
self, config, extension_type, version_key, get_supported_versions
self, config, aws_provider, extension_type, version_key
):
extensions = config.get("extensions", {})
if not extensions:
Expand All @@ -44,7 +45,8 @@ def _validate_extension_supported_versions(
if extension.get("type") == extension_type
]

supported_extension_versions = get_supported_versions()
# In this format so it can be monkey patched initially via mock_get_aws_supported_versions fixture
supported_extension_versions = versions.get_supported_aws_versions(aws_provider)
extensions_with_invalid_version = []

for extension in extensions_for_type:
Expand Down Expand Up @@ -74,21 +76,17 @@ def _validate_extension_supported_versions(
def validate_supported_redis_versions(self, config):
return self._validate_extension_supported_versions(
config=config,
extension_type="redis",
version_key="engine",
get_supported_versions=RedisProvider(
boto3.client("elasticache")
).get_supported_redis_versions,
aws_provider=RedisProvider(boto3.client("elasticache")),
extension_type="redis", # TODO this is information which can live in the RedisProvider
version_key="engine", # TODO this is information which can live in the RedisProvider
)

def validate_supported_opensearch_versions(self, config):
return self._validate_extension_supported_versions(
config=config,
extension_type="opensearch",
version_key="engine",
get_supported_versions=OpensearchProvider(
boto3.client("opensearch")
).get_supported_opensearch_versions,
aws_provider=OpensearchProvider(boto3.client("opensearch")),
extension_type="opensearch", # TODO this is information which can live in the OpensearchProvider
version_key="engine", # TODO this is information which can live in the OpensearchProvider
)

def validate_environment_pipelines(self, config):
Expand Down
31 changes: 31 additions & 0 deletions dbt_platform_helper/domain/versions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import List

from dbt_platform_helper.providers.aws.interfaces import GetReferenceProtocol
from dbt_platform_helper.providers.aws.interfaces import GetVersionsProtocol
from dbt_platform_helper.providers.cache import CacheProvider


class AwsGetVersionProtocol(GetReferenceProtocol, GetVersionsProtocol):
pass


# TODO this will be set up within the caching provider using the stragegy pattern
def get_supported_aws_versions(
client_provider: AwsGetVersionProtocol,
cache_provider=CacheProvider(),
) -> List[str]:
"""
For a given AWS client provider get the supported versions if the operation
is supported.

The cache value is retrieved if it exists.
"""
supported_versions = []
aws_reference = client_provider.get_reference()
if cache_provider.cache_refresh_required(aws_reference):
supported_versions = client_provider.get_supported_versions()
cache_provider.update_cache(aws_reference, supported_versions)
else:
supported_versions = cache_provider.read_supported_versions_from_cache(aws_reference)

return supported_versions
9 changes: 9 additions & 0 deletions dbt_platform_helper/providers/aws/interfaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Protocol


class GetVersionsProtocol(Protocol):
def get_supported_versions(self) -> list[str]: ...


class GetReferenceProtocol(Protocol):
def get_reference(self) -> str: ...
25 changes: 25 additions & 0 deletions dbt_platform_helper/providers/aws/opensearch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import boto3


class OpensearchProvider:

def __init__(self, client: boto3.client):
self.client = client
# TODO extract engine so you could swap between opensearch and elastic in the same provider
self.engine = "OpenSearch"

def get_reference(self) -> str:
return self.engine.lower()

def get_supported_versions(self) -> list[str]:
response = self.client.list_versions()
all_versions = response["Versions"]

opensearch_versions = [
version for version in all_versions if version.startswith(f"{self.engine}_")
]
supported_versions = [
version.removeprefix(f"{self.engine}_") for version in opensearch_versions
]

return supported_versions
21 changes: 21 additions & 0 deletions dbt_platform_helper/providers/aws/redis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import boto3


class RedisProvider:

def __init__(self, client: boto3.client):
self.client = client
self.engine = "redis"

def get_reference(self) -> str:
return self.engine.lower()

def get_supported_versions(self) -> list[str]:
supported_versions_response = self.client.describe_cache_engine_versions(Engine=self.engine)

supported_versions = [
version["EngineVersion"]
for version in supported_versions_response["CacheEngineVersions"]
]

return supported_versions
2 changes: 1 addition & 1 deletion dbt_platform_helper/providers/copilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from botocore.exceptions import ClientError

from dbt_platform_helper.constants import CONDUIT_DOCKER_IMAGE_LOCATION
from dbt_platform_helper.providers.aws import CreateTaskTimeoutException
from dbt_platform_helper.providers.aws.exceptions import CreateTaskTimeoutException
from dbt_platform_helper.providers.secrets import Secrets
from dbt_platform_helper.utils.application import Application
from dbt_platform_helper.utils.messages import abort_with_error
Expand Down
36 changes: 0 additions & 36 deletions dbt_platform_helper/providers/opensearch.py

This file was deleted.

34 changes: 0 additions & 34 deletions dbt_platform_helper/providers/redis.py

This file was deleted.

10 changes: 6 additions & 4 deletions dbt_platform_helper/utils/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

from dbt_platform_helper.constants import REFRESH_TOKEN_MESSAGE
from dbt_platform_helper.platform_exception import PlatformException
from dbt_platform_helper.providers.aws import CopilotCodebaseNotFoundException
from dbt_platform_helper.providers.aws import ImageNotFoundException
from dbt_platform_helper.providers.aws import LogGroupNotFoundException
from dbt_platform_helper.providers.aws import RepositoryNotFoundException
from dbt_platform_helper.providers.aws.exceptions import (
CopilotCodebaseNotFoundException,
)
from dbt_platform_helper.providers.aws.exceptions import ImageNotFoundException
from dbt_platform_helper.providers.aws.exceptions import LogGroupNotFoundException
from dbt_platform_helper.providers.aws.exceptions import RepositoryNotFoundException
from dbt_platform_helper.providers.validation import ValidationException

SSM_BASE_PATH = "/copilot/{app}/{env}/secrets/"
Expand Down
3 changes: 1 addition & 2 deletions platform-config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ environment_pipelines:
slack_channel: /codebuild/notification_channel
trigger_on_push: false
versions:
invalid-key: 1.2.3
platform-helper: main
environments:
'*':
Expand All @@ -69,8 +70,6 @@ environments:
id: '6677889900'
name: non-prod-dns-acc
requires_approval: false
versions:
invalid-key: 1.2.3
vpc: non-prod-vpc
dev: null
hotfix:
Expand Down
21 changes: 4 additions & 17 deletions tests/platform_helper/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
from moto import mock_aws
from moto.ec2 import utils as ec2_utils

import dbt_platform_helper.domain.versions as versions
from dbt_platform_helper.constants import PLATFORM_CONFIG_FILE
from dbt_platform_helper.providers.opensearch import OpensearchProvider
from dbt_platform_helper.providers.redis import RedisProvider
from dbt_platform_helper.utils.aws import AWS_SESSION_CACHE
from dbt_platform_helper.utils.versioning import PlatformHelperVersions

Expand Down Expand Up @@ -802,23 +801,11 @@ def create_invalid_platform_config_file(fakefs):

# TODO - stop gap until validation.py is refactored into a class, then it will be an easier job of just passing in a mock_redis_provider into the constructor for the config_provider. For now autouse is needed.
@pytest.fixture(autouse=True)
def mock_get_supported_opensearch_versions(request, monkeypatch):
if "skip_opensearch_fixture" in request.keywords:
return

def mock_return_value(self):
return ["1.0", "1.1", "1.2"]

monkeypatch.setattr(OpensearchProvider, "get_supported_opensearch_versions", mock_return_value)


# TODO - stop gap until validation.py is refactored into a class, then it will be an easier job of just passing in a mock_redis_provider into the constructor for the config_provider. For now autouse is needed.
@pytest.fixture(autouse=True)
def mock_get_supported_redis_versions(request, monkeypatch):
if "skip_redis_fixture" in request.keywords:
def mock_get_aws_supported_versions(request, monkeypatch):
if "skip_supported_versions_fixture" in request.keywords:
return

def mock_return_value(self):
return ["6.2", "7.0", "7.1"]

monkeypatch.setattr(RedisProvider, "get_supported_redis_versions", mock_return_value)
monkeypatch.setattr(versions, "get_supported_aws_versions", mock_return_value)
4 changes: 2 additions & 2 deletions tests/platform_helper/domain/test_codebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from dbt_platform_helper.domain.codebase import ApplicationEnvironmentNotFoundException
from dbt_platform_helper.domain.codebase import Codebase
from dbt_platform_helper.domain.codebase import NotInCodeBaseRepositoryException
from dbt_platform_helper.providers.aws import ImageNotFoundException
from dbt_platform_helper.providers.aws import RepositoryNotFoundException
from dbt_platform_helper.providers.aws.exceptions import ImageNotFoundException
from dbt_platform_helper.providers.aws.exceptions import RepositoryNotFoundException
from dbt_platform_helper.utils.application import ApplicationNotFoundException
from dbt_platform_helper.utils.application import Environment
from dbt_platform_helper.utils.git import CommitNotFoundException
Expand Down
2 changes: 1 addition & 1 deletion tests/platform_helper/domain/test_conduit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from dbt_platform_helper.domain.conduit import Conduit
from dbt_platform_helper.providers.aws import CreateTaskTimeoutException
from dbt_platform_helper.providers.aws.exceptions import CreateTaskTimeoutException
from dbt_platform_helper.providers.ecs import ECSAgentNotRunningException
from dbt_platform_helper.providers.ecs import NoClusterException
from dbt_platform_helper.providers.secrets import AddonNotFoundException
Expand Down
50 changes: 50 additions & 0 deletions tests/platform_helper/domain/test_get_supported_aws_versions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from unittest.mock import MagicMock

from dbt_platform_helper.domain.versions import get_supported_aws_versions


def test_get_supported_versions_cache_refresh():
mock_cache_provider = MagicMock()
mock_aws_provider = MagicMock()
setattr(mock_aws_provider, "get_reference", MagicMock(return_value="doesnt-matter"))
setattr(
mock_aws_provider,
"get_supported_versions",
MagicMock(return_value=["doesnt", "matter"]),
)
mock_cache_provider.cache_refresh_required.return_value = True

versions = get_supported_aws_versions(mock_aws_provider, mock_cache_provider)

mock_aws_provider.get_reference.assert_called()
mock_aws_provider.get_supported_versions.assert_called()
mock_cache_provider.update_cache.assert_called_with("doesnt-matter", ["doesnt", "matter"])
mock_cache_provider.read_supported_versions_from_cache.assert_not_called()

assert versions == ["doesnt", "matter"]


def test_get_supported_versions_no_cache_refresh():
mock_cache_provider = MagicMock()
mock_aws_provider = MagicMock()
setattr(mock_aws_provider, "get_reference", MagicMock(return_value="doesnt-matter"))
setattr(
mock_aws_provider,
"get_supported_versions",
MagicMock(return_value=["doesnt", "matter"]),
)
mock_cache_provider.cache_refresh_required.return_value = False
mock_cache_provider.read_supported_versions_from_cache.return_value = [
"cache",
"doesnt",
"matter",
]

versions = get_supported_aws_versions(mock_aws_provider, mock_cache_provider)

mock_aws_provider.get_reference.assert_called()
mock_aws_provider.get_supported_versions.assert_not_called()
mock_cache_provider.update_cache.assert_not_called()
mock_cache_provider.read_supported_versions_from_cache.assert_called_with("doesnt-matter")

assert versions == ["cache", "doesnt", "matter"]
Loading
Loading