diff --git a/databuilder/publisher/aws_sqs_csv_publisher.py b/databuilder/publisher/aws_sqs_csv_publisher.py
new file mode 100644
index 000000000..6df2d9348
--- /dev/null
+++ b/databuilder/publisher/aws_sqs_csv_publisher.py
@@ -0,0 +1,175 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+import csv
+import ctypes
+import json
+import logging
+import time
+from io import open
+from os import listdir
+from os.path import isfile, join
+from typing import List
+
+import boto3
+import pandas
+from botocore.config import Config
+from pyhocon import ConfigFactory, ConfigTree
+
+from databuilder.publisher.base_publisher import Publisher
+
+# Setting field_size_limit to solve the error below
+# _csv.Error: field larger than field limit (131072)
+# https://stackoverflow.com/a/54517228/5972935
+csv.field_size_limit(int(ctypes.c_ulong(-1).value // 2))
+
+# Config keys
+# A directory that contains CSV files for nodes
+NODE_FILES_DIR = 'node_files_directory'
+# A directory that contains CSV files for relationships
+RELATION_FILES_DIR = 'relation_files_directory'
+
+# AWS SQS configs
+# AWS SQS region
+AWS_SQS_REGION = 'aws_sqs_region'
+# AWS SQS URL to send messages to
+AWS_SQS_URL = 'aws_sqs_url'
+# AWS SQS message group id
+AWS_SQS_MESSAGE_GROUP_ID = 'aws_sqs_message_group_id'
+# Credential configuration for AWS SQS
+AWS_SQS_ACCESS_KEY_ID = 'aws_sqs_access_key_id'
+AWS_SQS_SECRET_ACCESS_KEY = 'aws_sqs_secret_access_key'
+
+# This will be used to provide a unique tag to the nodes and relationships
+JOB_PUBLISH_TAG = 'job_publish_tag'
+
+# CSV HEADER
+# A header with this suffix will be passed to the Neo4j statement without quotes
+UNQUOTED_SUFFIX = ':UNQUOTED'
+# A header for Node label
+NODE_LABEL_KEY = 'LABEL'
+# A header for Node key
+NODE_KEY_KEY = 'KEY'
+# Required columns for Node
+NODE_REQUIRED_KEYS = {NODE_LABEL_KEY, NODE_KEY_KEY}
+
+DEFAULT_CONFIG = ConfigFactory.from_dict({AWS_SQS_MESSAGE_GROUP_ID: 'metadata'})
+
+LOGGER = logging.getLogger(__name__)
+
+
+class AWSSQSCsvPublisher(Publisher):
+    """
+    A Publisher that takes two folders as input and publishes their contents as a message to AWS SQS.
+    One folder contains CSV file(s) for nodes, while the other contains CSV file(s) for relationships.
+    If the target AWS SQS queue does not use content-based deduplication, a message deduplication ID should be provided.
+    A single SQS message is limited to 256 KB; if a message is larger than that, an error is logged.
+    """
+
+    def __init__(self) -> None:
+        super(AWSSQSCsvPublisher, self).__init__()
+
+    def init(self, conf: ConfigTree) -> None:
+        conf = conf.with_fallback(DEFAULT_CONFIG)
+
+        self._node_files = self._list_files(conf, NODE_FILES_DIR)
+        self._node_files_iter = iter(self._node_files)
+
+        self._relation_files = self._list_files(conf, RELATION_FILES_DIR)
+        self._relation_files_iter = iter(self._relation_files)
+
+        # Initialize AWS SQS client
+        self.client = self._get_client(conf=conf)
+        self.aws_sqs_url = conf.get_string(AWS_SQS_URL)
+        self.message_group_id = conf.get_string(AWS_SQS_MESSAGE_GROUP_ID)
+
+        LOGGER.info('Publishing Node CSV files {}, and Relation CSV files {}'
+                    .format(self._node_files, self._relation_files))
+
+    def _list_files(self, conf: ConfigTree, path_key: str) -> List[str]:
+        """
+        List files from directory
+        :param conf:
+        :param path_key:
+        :return: List of file paths
+        """
+        if path_key not in conf:
+            return []
+
+        path = conf.get_string(path_key)
+        return [join(path, f) for f in listdir(path) if isfile(join(path, f))]
+
+    def publish_impl(self) -> None:  # noqa: C901
+        """
+        Publishes Nodes first and then Relations
+        :return:
+        """
+
+        start = time.time()
+
+        LOGGER.info('Publishing Node files: {}'.format(self._node_files))
+        nodes = []
+        relations = []
+
+        try:
+            while True:
+                try:
+                    node_file = next(self._node_files_iter)
+                    nodes.extend(self._publish_record(node_file))
+                except StopIteration:
+                    break
+
+            LOGGER.info('Publishing Relationship files: {}'.format(self._relation_files))
+            while True:
+                try:
+                    relation_file = next(self._relation_files_iter)
+                    relations.extend(self._publish_record(relation_file))
+                except StopIteration:
+                    break
+
+            message_body = {
+                'nodes': nodes,
+                'relations': relations
+            }
+
+            LOGGER.info('Publish nodes and relationships to Queue {}'.format(self.aws_sqs_url))
+
+            self.client.send_message(
+                QueueUrl=self.aws_sqs_url,
+                MessageBody=json.dumps(message_body),
+                MessageGroupId=self.message_group_id
+            )
+
+            LOGGER.info('Successfully published. Elapsed: {} seconds'.format(time.time() - start))
+        except Exception as e:
+            LOGGER.exception('Failed to publish.')
+            raise e
+
+    def get_scope(self) -> str:
+        return 'publisher.awssqs'
+
+    def _publish_record(self, csv_file: str) -> list:
+        """
+        Iterate over the records of a CSV file; each record is converted to a dict and appended to a list.
+        All nodes and relations (one per CSV record) should have a unique key.
+        :param csv_file:
+        :return:
+        """
+        ret = []
+
+        with open(csv_file, 'r', encoding='utf8') as record_csv:
+            for record in pandas.read_csv(record_csv, na_filter=False).to_dict(orient="records"):
+                ret.append(record)
+
+        return ret
+
+    def _get_client(self, conf: ConfigTree) -> boto3.client:
+        """
+        Create a client object to access AWS SQS
+        :return:
+        """
+        return boto3.client('sqs',
+                            aws_access_key_id=conf.get_string(AWS_SQS_ACCESS_KEY_ID),
+                            aws_secret_access_key=conf.get_string(AWS_SQS_SECRET_ACCESS_KEY),
+                            config=Config(region_name=conf.get_string(AWS_SQS_REGION))
+                            )
diff --git a/example/scripts/sample_mysql_to_aws_sqs_publisher.py b/example/scripts/sample_mysql_to_aws_sqs_publisher.py
new file mode 100644
index 000000000..a479d3b49
--- /dev/null
+++ b/example/scripts/sample_mysql_to_aws_sqs_publisher.py
@@ -0,0 +1,81 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import textwrap
+
+from pyhocon import ConfigFactory
+
+from databuilder.extractor.mysql_metadata_extractor import MysqlMetadataExtractor
+from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor
+from databuilder.job.job import DefaultJob
+from databuilder.loader.file_system_neo4j_csv_loader import FsNeo4jCSVLoader
+from databuilder.publisher import aws_sqs_csv_publisher
+from databuilder.publisher.aws_sqs_csv_publisher import AWSSQSCsvPublisher
+from databuilder.task.task import DefaultTask
+
+# TODO: AWS SQS url, region and credentials need to change
+AWS_SQS_REGION = os.getenv('AWS_SQS_REGION', 'ap-northeast-2')
+AWS_SQS_URL = os.getenv('AWS_SQS_URL', 'https://sqs.ap-northeast-2.amazonaws.com')
+AWS_SQS_ACCESS_KEY_ID = os.getenv('AWS_SQS_ACCESS_KEY_ID', '')
+AWS_SQS_SECRET_ACCESS_KEY = os.getenv('AWS_SQS_SECRET_ACCESS_KEY', '')
+
+# TODO: connection string needs to change
+# Source DB configuration
+DATABASE_HOST = os.getenv('DATABASE_HOST', 'localhost')
+DATABASE_PORT = os.getenv('DATABASE_PORT', '3306')
+DATABASE_USER = os.getenv('DATABASE_USER', 'root')
+DATABASE_PASSWORD = os.getenv('DATABASE_PASSWORD', 'root')
+DATABASE_DB_NAME = os.getenv('DATABASE_DB_NAME', 'mysql')
+
+MYSQL_CONN_STRING = \
+    f'mysql://{DATABASE_USER}:{DATABASE_PASSWORD}@{DATABASE_HOST}:{DATABASE_PORT}/{DATABASE_DB_NAME}'
+
+
+def run_mysql_job() -> DefaultJob:
+    where_clause_suffix = textwrap.dedent("""
+        where c.table_schema = 'mysql'
+    """)
+
+    tmp_folder = '/var/tmp/amundsen/table_metadata'
+    node_files_folder = '{tmp_folder}/nodes/'.format(tmp_folder=tmp_folder)
+    relationship_files_folder = '{tmp_folder}/relationships/'.format(tmp_folder=tmp_folder)
+
+    job_config = ConfigFactory.from_dict({
+        'extractor.mysql_metadata.{}'.format(MysqlMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY):
+            where_clause_suffix,
+        'extractor.mysql_metadata.{}'.format(MysqlMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME):
+            True,
+        'extractor.mysql_metadata.extractor.sqlalchemy.{}'.format(SQLAlchemyExtractor.CONN_STRING):
+            MYSQL_CONN_STRING,
+        'loader.filesystem_csv_neo4j.{}'.format(FsNeo4jCSVLoader.NODE_DIR_PATH):
+            node_files_folder,
+        'loader.filesystem_csv_neo4j.{}'.format(FsNeo4jCSVLoader.RELATION_DIR_PATH):
+            relationship_files_folder,
+        'publisher.awssqs.{}'.format(aws_sqs_csv_publisher.NODE_FILES_DIR):
+            node_files_folder,
+        'publisher.awssqs.{}'.format(aws_sqs_csv_publisher.RELATION_FILES_DIR):
+            relationship_files_folder,
+        'publisher.awssqs.{}'.format(aws_sqs_csv_publisher.AWS_SQS_REGION):
+            AWS_SQS_REGION,
+        'publisher.awssqs.{}'.format(aws_sqs_csv_publisher.AWS_SQS_URL):
+            AWS_SQS_URL,
+        'publisher.awssqs.{}'.format(aws_sqs_csv_publisher.AWS_SQS_ACCESS_KEY_ID):
+            AWS_SQS_ACCESS_KEY_ID,
+        'publisher.awssqs.{}'.format(aws_sqs_csv_publisher.AWS_SQS_SECRET_ACCESS_KEY):
+            AWS_SQS_SECRET_ACCESS_KEY,
+        'publisher.awssqs.{}'.format(aws_sqs_csv_publisher.JOB_PUBLISH_TAG):
+            'unique_tag',  # should use unique tag here like {ds}
+    })
+    job = DefaultJob(conf=job_config,
+                     task=DefaultTask(extractor=MysqlMetadataExtractor(), loader=FsNeo4jCSVLoader()),
+                     publisher=AWSSQSCsvPublisher())
+    return job
+
+
+if __name__ == "__main__":
+    # Uncomment next line to get INFO level logging
+    # logging.basicConfig(level=logging.INFO)
+
+    mysql_job = run_mysql_job()
+    mysql_job.launch()
diff --git a/setup.py b/setup.py
index b7bde3eb8..f5dc59cfc 100644
--- a/setup.py
+++ b/setup.py
@@ -69,8 +69,12 @@
     'pyatlasclient==1.1.2'
 ]
 
+aws = [
+    'boto3>=1.10.1'
+]
+
 all_deps = requirements + kafka + cassandra + glue + snowflake + athena + \
-    bigquery + jsonpath + db2 + dremio + druid + spark + feast
+    bigquery + jsonpath + db2 + dremio + druid + spark + feast + aws
 
 setup(
     name='amundsen-databuilder',
@@ -97,7 +101,8 @@
         'druid': druid,
         'delta': spark,
         'feast': feast,
-        'atlas': atlas
+        'atlas': atlas,
+        'aws': aws
     },
     classifiers=[
        'Programming Language :: Python :: 3.6',
diff --git a/tests/unit/publisher/test_aws_sqs_csv_publisher.py b/tests/unit/publisher/test_aws_sqs_csv_publisher.py
new file mode 100644
index 000000000..74ccc021f
--- /dev/null
+++ b/tests/unit/publisher/test_aws_sqs_csv_publisher.py
@@ -0,0 +1,49 @@
+# Copyright Contributors to the Amundsen project.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import os
+import unittest
+import uuid
+
+import boto3
+from mock import MagicMock, patch
+from pyhocon import ConfigFactory
+
+from databuilder.publisher import aws_sqs_csv_publisher
+from databuilder.publisher.aws_sqs_csv_publisher import AWSSQSCsvPublisher
+
+here = os.path.dirname(__file__)
+
+
+class TestAWSSQSPublish(unittest.TestCase):
+
+    def setUp(self) -> None:
+        logging.basicConfig(level=logging.INFO)
+        self._resource_path = os.path.join(here, '../resources/csv_publisher')
+
+    def test_publisher(self) -> None:
+        with patch.object(boto3, 'client') as mock_client, \
+                patch.object(AWSSQSCsvPublisher, '_publish_record') as mock_publish_record:
+
+            mock_send_message = MagicMock()
+            mock_client.return_value.send_message = mock_send_message
+
+            publisher = AWSSQSCsvPublisher()
+
+            conf = ConfigFactory.from_dict(
+                {aws_sqs_csv_publisher.NODE_FILES_DIR: f'{self._resource_path}/nodes',
+                 aws_sqs_csv_publisher.RELATION_FILES_DIR: f'{self._resource_path}/relations',
+                 aws_sqs_csv_publisher.AWS_SQS_REGION: 'aws_region',
+                 aws_sqs_csv_publisher.AWS_SQS_URL: 'aws_sqs_url',
+                 aws_sqs_csv_publisher.AWS_SQS_ACCESS_KEY_ID: 'aws_account_access_key_id',
+                 aws_sqs_csv_publisher.AWS_SQS_SECRET_ACCESS_KEY: 'aws_account_secret_access_key',
+                 aws_sqs_csv_publisher.JOB_PUBLISH_TAG: str(uuid.uuid4())}
+            )
+            publisher.init(conf)
+            publisher.publish()
+
+            # 2 node files and 1 relation file
+            self.assertEqual(mock_publish_record.call_count, 3)
+
+            self.assertEqual(mock_send_message.call_count, 1)
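
Note (not part of the diff above): the change covers only the publishing side. As a rough illustration of how a downstream service might consume the published metadata, here is a minimal sketch, assuming a hypothetical queue URL and default AWS credentials; the message body shape ({'nodes': [...], 'relations': [...]}) mirrors what publish_impl sends, and the queue name and account ID below are placeholders.

# Minimal consumer sketch (illustrative only, not part of this change).
import json

import boto3

# Placeholder queue URL; replace with the actual queue used by the publisher.
QUEUE_URL = 'https://sqs.ap-northeast-2.amazonaws.com/123456789012/amundsen-metadata.fifo'

sqs = boto3.client('sqs', region_name='ap-northeast-2')

# Long-poll for a single message published by AWSSQSCsvPublisher.
response = sqs.receive_message(
    QueueUrl=QUEUE_URL,
    MaxNumberOfMessages=1,
    WaitTimeSeconds=10,
)

for message in response.get('Messages', []):
    body = json.loads(message['Body'])
    # 'nodes' and 'relations' are lists of dicts, one per CSV record.
    print(len(body['nodes']), 'nodes;', len(body['relations']), 'relations')
    # Delete the message once it has been processed.
    sqs.delete_message(QueueUrl=QUEUE_URL, ReceiptHandle=message['ReceiptHandle'])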