feat: Add AWS SQS publisher and example job file. #431

Closed · wants to merge 10 commits
175 changes: 175 additions & 0 deletions databuilder/publisher/aws_sqs_csv_publisher.py
@@ -0,0 +1,175 @@
# Copyright Contributors to the Amundsen project.
# SPDX-License-Identifier: Apache-2.0

import csv
import ctypes
import json
import logging
import time
from io import open
from os import listdir
from os.path import isfile, join
from typing import List

import boto3
import pandas
from botocore.config import Config
from pyhocon import ConfigFactory, ConfigTree

from databuilder.publisher.base_publisher import Publisher

# Setting field_size_limit to solve the error below
# _csv.Error: field larger than field limit (131072)
# https://stackoverflow.com/a/54517228/5972935
csv.field_size_limit(int(ctypes.c_ulong(-1).value // 2))

# Config keys
# A directory that contains CSV files for nodes
NODE_FILES_DIR = 'node_files_directory'
# A directory that contains CSV files for relationships
RELATION_FILES_DIR = 'relation_files_directory'

# AWS SQS configs
# AWS SQS region
AWS_SQS_REGION = 'aws_sqs_region'
# AWS SQS url to send a message
AWS_SQS_URL = 'aws_sqs_url'
# AWS SQS message group id
AWS_SQS_MESSAGE_GROUP_ID = 'aws_sqs_message_group_id'
# credential configuration of AWS SQS
AWS_SQS_ACCESS_KEY_ID = 'aws_sqs_access_key_id'
AWS_SQS_SECRET_ACCESS_KEY = 'aws_sqs_secret_access_key'

# This will be used to provide a unique tag to the nodes and relationships
JOB_PUBLISH_TAG = 'job_publish_tag'

# CSV HEADER
# A header with this suffix will be passed to the Neo4j statement without quotes
UNQUOTED_SUFFIX = ':UNQUOTED'
# A header for Node label
NODE_LABEL_KEY = 'LABEL'
# A header for Node key
NODE_KEY_KEY = 'KEY'
# Required columns for Node
NODE_REQUIRED_KEYS = {NODE_LABEL_KEY, NODE_KEY_KEY}

DEFAULT_CONFIG = ConfigFactory.from_dict({AWS_SQS_MESSAGE_GROUP_ID: 'metadata'})

LOGGER = logging.getLogger(__name__)


class AWSSQSCsvPublisher(Publisher):
"""
A Publisher takes two folders for input and publishes it as message to AWS SQS.
One folder will contain CSV file(s) for Node where the other folder will contain CSV file(s) for Relationship.
If the target AWS SQS Queue does not use content based deduplication, Message ID should be defined.
Single message size is limited to 250 KB. if one message size is larger than that, error logs will be printed.
"""

    def __init__(self) -> None:
        super(AWSSQSCsvPublisher, self).__init__()

    def init(self, conf: ConfigTree) -> None:
        conf = conf.with_fallback(DEFAULT_CONFIG)

        self._node_files = self._list_files(conf, NODE_FILES_DIR)
        self._node_files_iter = iter(self._node_files)

        self._relation_files = self._list_files(conf, RELATION_FILES_DIR)
        self._relation_files_iter = iter(self._relation_files)

        # Initialize AWS SQS client
        self.client = self._get_client(conf=conf)
        self.aws_sqs_url = conf.get_string(AWS_SQS_URL)
        self.message_group_id = conf.get_string(AWS_SQS_MESSAGE_GROUP_ID)

        LOGGER.info('Publishing Node CSV files {} and Relation CSV files {}'
                    .format(self._node_files, self._relation_files))

    def _list_files(self, conf: ConfigTree, path_key: str) -> List[str]:
        """
        List files from directory
        :param conf:
        :param path_key:
        :return: List of file paths
        """
        if path_key not in conf:
            return []

        path = conf.get_string(path_key)
        return [join(path, f) for f in listdir(path) if isfile(join(path, f))]

    def publish_impl(self) -> None:  # noqa: C901
        """
        Publishes Nodes first and then Relations
        :return:
        """

        start = time.time()

        LOGGER.info('Publishing Node files: {}'.format(self._node_files))
        nodes = []
        relations = []

        try:
            while True:
                try:
                    node_file = next(self._node_files_iter)
                    nodes.extend(self._publish_record(node_file))
                except StopIteration:
                    break

            LOGGER.info('Publishing Relationship files: {}'.format(self._relation_files))
            while True:
                try:
                    relation_file = next(self._relation_files_iter)
                    relations.extend(self._publish_record(relation_file))
                except StopIteration:
                    break

            message_body = {
                'nodes': nodes,
                'relations': relations
            }

            LOGGER.info('Publish nodes and relationships to Queue {}'.format(self.aws_sqs_url))

            self.client.send_message(
                QueueUrl=self.aws_sqs_url,
                MessageBody=json.dumps(message_body),
                MessageGroupId=self.message_group_id
            )

            LOGGER.info('Successfully published. Elapsed: {} seconds'.format(time.time() - start))
        except Exception as e:
            LOGGER.exception('Failed to publish.')
            raise e

    def get_scope(self) -> str:
        return 'publisher.awssqs'

    def _publish_record(self, csv_file: str) -> list:
        """
        Iterate over the CSV records of a file; each record is converted to a dict and appended to a list.
        All nodes and relations (each CSV row is one record) should have a unique key.
        :param csv_file:
        :return:
        """
        ret = []

        with open(csv_file, 'r', encoding='utf8') as record_csv:
            for record in pandas.read_csv(record_csv, na_filter=False).to_dict(orient="records"):
                ret.append(record)

        return ret

    def _get_client(self, conf: ConfigTree) -> boto3.client:
        """
        Create a client object to access AWS SQS
        :return:
        """
        return boto3.client('sqs',
                            aws_access_key_id=conf.get_string(AWS_SQS_ACCESS_KEY_ID),
                            aws_secret_access_key=conf.get_string(AWS_SQS_SECRET_ACCESS_KEY),
                            config=Config(region_name=conf.get_string(AWS_SQS_REGION))
                            )
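For reference, publish_impl sends a single SQS message whose body is a JSON object with 'nodes' and 'relations' keys, each holding one dict per CSV row. A minimal consumer-side sketch (illustration only, not part of this PR; the queue URL, region, and account ID below are placeholders) might look like this:

import json

import boto3

# Hypothetical queue URL and region -- replace with real values.
QUEUE_URL = 'https://sqs.ap-northeast-2.amazonaws.com/123456789012/example-queue.fifo'

sqs = boto3.client('sqs', region_name='ap-northeast-2')

# Long-poll for the message published by AWSSQSCsvPublisher.
response = sqs.receive_message(QueueUrl=QUEUE_URL, MaxNumberOfMessages=1, WaitTimeSeconds=10)

for message in response.get('Messages', []):
    body = json.loads(message['Body'])
    # 'nodes' and 'relations' are lists of dicts, one per CSV row.
    print(f"received {len(body['nodes'])} nodes and {len(body['relations'])} relations")
    # Remove the message from the queue once it has been processed.
    sqs.delete_message(QueueUrl=QUEUE_URL, ReceiptHandle=message['ReceiptHandle'])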
81 changes: 81 additions & 0 deletions example/scripts/sample_mysql_to_aws_sqs_publisher.py
@@ -0,0 +1,81 @@
# Copyright Contributors to the Amundsen project.
# SPDX-License-Identifier: Apache-2.0

import os
import textwrap

from pyhocon import ConfigFactory

from databuilder.extractor.mysql_metadata_extractor import MysqlMetadataExtractor
from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor
from databuilder.job.job import DefaultJob
from databuilder.loader.file_system_neo4j_csv_loader import FsNeo4jCSVLoader
from databuilder.publisher import aws_sqs_csv_publisher
from databuilder.publisher.aws_sqs_csv_publisher import AWSSQSCsvPublisher
from databuilder.task.task import DefaultTask

# TODO: AWS SQS url, region and credentials need to be changed
AWS_SQS_REGION = os.getenv('AWS_SQS_REGION', 'ap-northeast-2')
AWS_SQS_URL = os.getenv('AWS_SQS_URL', 'https://sqs.ap-northeast-2.amazonaws.com')
AWS_SQS_ACCESS_KEY_ID = os.getenv('AWS_SQS_ACCESS_KEY_ID', '')
AWS_SQS_SECRET_ACCESS_KEY = os.getenv('AWS_SQS_SECRET_ACCESS_KEY', '')

# TODO: connection string needs to be changed
# Source DB configuration
DATABASE_HOST = os.getenv('DATABASE_HOST', 'localhost')
DATABASE_PORT = os.getenv('DATABASE_PORT', '3306')
DATABASE_USER = os.getenv('DATABASE_USER', 'root')
DATABASE_PASSWORD = os.getenv('DATABASE_PASSWORD', 'root')
DATABASE_DB_NAME = os.getenv('DATABASE_DB_NAME', 'mysql')

MYSQL_CONN_STRING = \
    f'mysql://{DATABASE_USER}:{DATABASE_PASSWORD}@{DATABASE_HOST}:{DATABASE_PORT}/{DATABASE_DB_NAME}'


def run_mysql_job() -> DefaultJob:
    where_clause_suffix = textwrap.dedent("""
        where c.table_schema = 'mysql'
    """)

    tmp_folder = '/var/tmp/amundsen/table_metadata'
    node_files_folder = '{tmp_folder}/nodes/'.format(tmp_folder=tmp_folder)
    relationship_files_folder = '{tmp_folder}/relationships/'.format(tmp_folder=tmp_folder)

    job_config = ConfigFactory.from_dict({
        'extractor.mysql_metadata.{}'.format(MysqlMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY):
            where_clause_suffix,
        'extractor.mysql_metadata.{}'.format(MysqlMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME):
            True,
        'extractor.mysql_metadata.extractor.sqlalchemy.{}'.format(SQLAlchemyExtractor.CONN_STRING):
            MYSQL_CONN_STRING,
        'loader.filesystem_csv_neo4j.{}'.format(FsNeo4jCSVLoader.NODE_DIR_PATH):
            node_files_folder,
        'loader.filesystem_csv_neo4j.{}'.format(FsNeo4jCSVLoader.RELATION_DIR_PATH):
            relationship_files_folder,
        'publisher.awssqs.{}'.format(aws_sqs_csv_publisher.NODE_FILES_DIR):
            node_files_folder,
        'publisher.awssqs.{}'.format(aws_sqs_csv_publisher.RELATION_FILES_DIR):
            relationship_files_folder,
        'publisher.awssqs.{}'.format(aws_sqs_csv_publisher.AWS_SQS_REGION):
            AWS_SQS_REGION,
        'publisher.awssqs.{}'.format(aws_sqs_csv_publisher.AWS_SQS_URL):
            AWS_SQS_URL,
        'publisher.awssqs.{}'.format(aws_sqs_csv_publisher.AWS_SQS_ACCESS_KEY_ID):
            AWS_SQS_ACCESS_KEY_ID,
        'publisher.awssqs.{}'.format(aws_sqs_csv_publisher.AWS_SQS_SECRET_ACCESS_KEY):
            AWS_SQS_SECRET_ACCESS_KEY,
        'publisher.awssqs.{}'.format(aws_sqs_csv_publisher.JOB_PUBLISH_TAG):
            'unique_tag',  # should use unique tag here like {ds}
    })
    job = DefaultJob(conf=job_config,
                     task=DefaultTask(extractor=MysqlMetadataExtractor(), loader=FsNeo4jCSVLoader()),
                     publisher=AWSSQSCsvPublisher())
    return job


if __name__ == "__main__":
    # Uncomment the next line (and import logging) to get INFO level logging
    # logging.basicConfig(level=logging.INFO)

    mysql_job = run_mysql_job()
    mysql_job.launch()
9 changes: 7 additions & 2 deletions setup.py
@@ -69,8 +69,12 @@
     'pyatlasclient==1.1.2'
 ]
 
+aws = [
+    'boto3>=1.10.1'
+]
+
 all_deps = requirements + kafka + cassandra + glue + snowflake + athena + \
-    bigquery + jsonpath + db2 + dremio + druid + spark + feast
+    bigquery + jsonpath + db2 + dremio + druid + spark + feast + aws
 
 setup(
     name='amundsen-databuilder',
@@ -97,7 +101,8 @@
         'druid': druid,
         'delta': spark,
         'feast': feast,
-        'atlas': atlas
+        'atlas': atlas,
+        'aws': aws
     },
     classifiers=[
         'Programming Language :: Python :: 3.6',
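With the new extras entry above, the boto3 dependency for this publisher can be pulled in on demand, e.g. pip install 'amundsen-databuilder[aws]' from PyPI, or pip install '.[aws]' from a local checkout (assuming the standard pip extras workflow).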
49 changes: 49 additions & 0 deletions tests/unit/publisher/test_aws_sqs_csv_publisher.py
@@ -0,0 +1,49 @@
# Copyright Contributors to the Amundsen project.
# SPDX-License-Identifier: Apache-2.0

import logging
import os
import unittest
import uuid

import boto3
from mock import MagicMock, patch
from pyhocon import ConfigFactory

from databuilder.publisher import aws_sqs_csv_publisher
from databuilder.publisher.aws_sqs_csv_publisher import AWSSQSCsvPublisher

here = os.path.dirname(__file__)


class TestAWSSQSPublish(unittest.TestCase):

    def setUp(self) -> None:
        logging.basicConfig(level=logging.INFO)
        self._resource_path = os.path.join(here, '../resources/csv_publisher')

    def test_publisher(self) -> None:
        with patch.object(boto3, 'client') as mock_client, \
                patch.object(AWSSQSCsvPublisher, '_publish_record') as mock_publish_record:

            mock_send_message = MagicMock()
            mock_client.return_value.send_message = mock_send_message

            publisher = AWSSQSCsvPublisher()

            conf = ConfigFactory.from_dict(
                {aws_sqs_csv_publisher.NODE_FILES_DIR: f'{self._resource_path}/nodes',
                 aws_sqs_csv_publisher.RELATION_FILES_DIR: f'{self._resource_path}/relations',
                 aws_sqs_csv_publisher.AWS_SQS_REGION: 'aws_region',
                 aws_sqs_csv_publisher.AWS_SQS_URL: 'aws_sqs_url',
                 aws_sqs_csv_publisher.AWS_SQS_ACCESS_KEY_ID: 'aws_account_access_key_id',
                 aws_sqs_csv_publisher.AWS_SQS_SECRET_ACCESS_KEY: 'aws_account_secret_access_key',
                 aws_sqs_csv_publisher.JOB_PUBLISH_TAG: str(uuid.uuid4())}
            )
            publisher.init(conf)
            publisher.publish()

            # 2 node files and 1 relation file
            self.assertEqual(mock_publish_record.call_count, 3)

            self.assertEqual(mock_send_message.call_count, 1)