Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix gcp credentials with dbt #44

Merged
merged 8 commits into from
Sep 6, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- `ValidationError` using `GcpCredentials.service_account_info` in `prefect-dbt` - [#44](https://github.com/PrefectHQ/prefect-gcp/pull/44)

### Security

## 0.1.3
Expand Down
33 changes: 15 additions & 18 deletions prefect_gcp/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import functools
import os
from pathlib import Path
from typing import Dict, Optional
from typing import Dict, Optional, Union

from google.oauth2.service_account import Credentials
from pydantic import Json
Expand Down Expand Up @@ -66,9 +66,9 @@ class GcpCredentials(Block):
[Authentication docs](https://cloud.google.com/docs/authentication/production)
for more info about the possible credential configurations.

Args:
Attributes:
service_account_file: Path to the service account JSON keyfile.
service_account_info: The contents of the keyfile as a JSON string / dictionary.
service_account_info: The contents of the keyfile as a dict or JSON string.
project: Name of the project to use.

Example:
Expand All @@ -83,7 +83,7 @@ class GcpCredentials(Block):
_block_type_name = "GCP Credentials"

service_account_file: Optional[Path] = None
service_account_info: Optional[Json] = None
service_account_info: Optional[Union[Dict[str, str], Json]] = None
project: Optional[str] = None

@staticmethod
Expand Down Expand Up @@ -139,26 +139,25 @@ def example_get_client_flow():
example_get_client_flow()
```

Gets a GCP Cloud Storage client from a JSON str.
Gets a GCP Cloud Storage client from a dictionary.
```python
import json
from prefect import flow
from prefect_gcp.credentials import GcpCredentials

@flow()
def example_get_client_flow():
service_account_info = json.dumps({
service_account_info = {
"type": "service_account",
"project_id": "project_id",
"private_key_id": "private_key_id",
"private_key": private_key",
"private_key": "private_key",
"client_email": "client_email",
"client_id": "client_id",
"auth_uri": "auth_uri",
"token_uri": "token_uri",
"auth_provider_x509_cert_url": "auth_provider_x509_cert_url",
"client_x509_cert_url": "client_x509_cert_url"
})
}
client = GcpCredentials(
service_account_info=service_account_info
).get_cloud_storage_client()
Expand Down Expand Up @@ -202,29 +201,28 @@ def example_get_client_flow():
example_get_client_flow()
```

Gets a GCP BigQuery client from a JSON str.
Gets a GCP BigQuery client from a dictionary.
```python
import json
from prefect import flow
from prefect_gcp.credentials import GcpCredentials

@flow()
def example_get_client_flow():
service_account_info = json.dumps({
service_account_info = {
"type": "service_account",
"project_id": "project_id",
"private_key_id": "private_key_id",
"private_key": private_key",
"private_key": "private_key",
"client_email": "client_email",
"client_id": "client_id",
"auth_uri": "auth_uri",
"token_uri": "token_uri",
"auth_provider_x509_cert_url": "auth_provider_x509_cert_url",
"client_x509_cert_url": "client_x509_cert_url"
})
}
client = GcpCredentials(
service_account_info=service_account_info
).get_bigquery_client(json)
).get_bigquery_client()

example_get_client_flow()
```
Expand Down Expand Up @@ -264,15 +262,14 @@ def example_get_client_flow():
example_get_client_flow()
```

Gets a GCP Cloud Storage client from a JSON str.
Gets a GCP Cloud Storage client from a dictionary.
```python
import json
from prefect import flow
from prefect_gcp.credentials import GcpCredentials

@flow()
def example_get_client_flow():
service_account_info = json.dumps({
service_account_info = {
"type": "service_account",
"project_id": "project_id",
"private_key_id": "private_key_id",
Expand Down
87 changes: 78 additions & 9 deletions tests/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from pathlib import Path

import pytest
from prefect import flow
from prefect import flow, task
from prefect.blocks.core import Block

from prefect_gcp import GcpCredentials

Expand All @@ -13,8 +14,20 @@


@pytest.fixture()
def service_account_info():
return '{"key": "abc", "pass": "pass"}'
def service_account_info_dict():
_service_account_info = {"key": "abc", "pass": "pass"}
return _service_account_info


@pytest.fixture()
def service_account_info_json(service_account_info_dict):
_service_account_info = json.dumps(service_account_info_dict)
return _service_account_info


@pytest.fixture(params=["service_account_info_dict", "service_account_info_json"])
def service_account_info(request):
return request.getfixturevalue(request.param)


@pytest.mark.parametrize("service_account_file", SERVICE_ACCOUNT_FILES)
Expand All @@ -28,12 +41,12 @@ def test_get_credentials_from_service_account_file(


def test_get_credentials_from_service_account_info(
service_account_info, oauth2_credentials
service_account_info_dict, oauth2_credentials
):
credentials = GcpCredentials._get_credentials_from_service_account(
service_account_info=service_account_info
service_account_info=service_account_info_dict
)
assert credentials == service_account_info
assert credentials == service_account_info_dict


def test_get_credentials_from_service_account_none(oauth2_credentials):
Expand All @@ -48,12 +61,12 @@ def test_get_credentials_from_service_account_file_error(oauth2_credentials):


def test_get_credentials_from_service_account_both_error(
service_account_info, oauth2_credentials
service_account_info_dict, oauth2_credentials
):
with pytest.raises(ValueError):
GcpCredentials._get_credentials_from_service_account(
service_account_file=SERVICE_ACCOUNT_FILES[0],
service_account_info=service_account_info,
service_account_info=service_account_info_dict,
)


Expand All @@ -64,12 +77,17 @@ def test_get_cloud_storage_client(
@flow
def test_flow():
project = "test_project"
print(service_account_info)
ahuang11 marked this conversation as resolved.
Show resolved Hide resolved
credentials = GcpCredentials(
service_account_info=service_account_info,
project=project,
)
client = credentials.get_cloud_storage_client(project=override_project)
assert client.credentials == json.loads(service_account_info)
if isinstance(service_account_info, str):
expected = json.loads(service_account_info)
else:
expected = service_account_info
assert client.credentials == expected

if override_project is None:
assert client.project == project
Expand All @@ -78,3 +96,54 @@ def test_flow():
return True

test_flow()


class MockTargetConfigs(Block):
credentials: GcpCredentials

def get_configs(self):
"""
Returns the dbt configs, likely used eventually for writing to profiles.yml.
Returns:
A configs JSON.
"""
return self.credentials.dict()


class MockCliProfile(Block):
target_configs: MockTargetConfigs

def get_profile(self):
profile = {
"name": {
"outputs": {"target": self.target_configs.get_configs()},
},
}
return profile


def test_credentials_is_able_to_serialize_back(service_account_info):
@task
def test_task(mock_cli_profile):
return mock_cli_profile.get_profile()

@flow
def test_flow():
gcp_credentials = GcpCredentials(service_account_info=service_account_info)
mock_target_configs = MockTargetConfigs(credentials=gcp_credentials)
mock_cli_profile = MockCliProfile(target_configs=mock_target_configs)
task_result = test_task(mock_cli_profile)
return task_result

expected = {
"name": {
"outputs": {
"target": {
"project": None,
"service_account_file": None,
"service_account_info": {"key": "abc", "pass": "pass"},
}
}
}
}
assert test_flow() == expected