From 258ef9d1e10db32aa5eabf1147464d08904cdd4b Mon Sep 17 00:00:00 2001 From: harjeevan maan Date: Fri, 6 Dec 2024 16:50:39 +0530 Subject: [PATCH] Added the ArangoDBCollectionOperator that executes collection operations in an ArangoDB database (#44676) * - Added the ArangoDBCollectionOperator that executes collection operations in an ArangoDB database - Insert, replace, update and delete documents functionality included * Fixed a failing test in hooks --- .../providers/arangodb/hooks/arangodb.py | 65 +++++++++++++- .../providers/arangodb/operators/arangodb.py | 88 ++++++++++++++++++- .../tests/arangodb/hooks/test_arangodb.py | 4 + .../tests/arangodb/operators/test_arangodb.py | 20 ++++- 4 files changed, 172 insertions(+), 5 deletions(-) diff --git a/providers/src/airflow/providers/arangodb/hooks/arangodb.py b/providers/src/airflow/providers/arangodb/hooks/arangodb.py index e0625b8b9e1bf..ca2eb6ebdb03c 100644 --- a/providers/src/airflow/providers/arangodb/hooks/arangodb.py +++ b/providers/src/airflow/providers/arangodb/hooks/arangodb.py @@ -23,12 +23,18 @@ from typing import TYPE_CHECKING, Any from arango import AQLQueryExecuteError, ArangoClient as ArangoDBClient +from arango.cursor import Cursor +from arango.exceptions import ( + DocumentDeleteError, + DocumentInsertError, + DocumentReplaceError, + DocumentUpdateError, +) from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook if TYPE_CHECKING: - from arango.cursor import Cursor from arango.database import StandardDatabase from airflow.models import Connection @@ -101,8 +107,8 @@ def query(self, query, **kwargs) -> Cursor: try: if self.db_conn: result = self.db_conn.aql.execute(query, **kwargs) - if TYPE_CHECKING: - assert isinstance(result, Cursor) + if not isinstance(result, Cursor): + raise AirflowException("Failed to execute AQLQuery, expected result to be of type Cursor") return result else: raise AirflowException( @@ -113,12 +119,21 @@ def query(self, query, **kwargs) -> Cursor: 
def create_collection(self, name): if not self.db_conn.has_collection(name): + self.log.info("Collection '%s' does not exist. Creating a new collection.", name) self.db_conn.create_collection(name) return True else: self.log.info("Collection already exists: %s", name) return False + def delete_collection(self, name): + if self.db_conn.has_collection(name): + self.db_conn.delete_collection(name) + return True + else: + self.log.info("Collection does not exist: %s", name) + return False + def create_database(self, name): if not self.db_conn.has_database(name): self.db_conn.create_database(name) @@ -135,6 +150,50 @@ def create_graph(self, name): self.log.info("Graph already exists: %s", name) return False + def insert_documents(self, collection_name, documents): + if not self.db_conn.has_collection(collection_name): + self.create_collection(collection_name) + + try: + collection = self.db_conn.collection(collection_name) + collection.insert_many(documents, silent=True) + except DocumentInsertError as e: + self.log.error("Failed to insert documents: %s", str(e)) + raise + + def update_documents(self, collection_name, documents): + if not self.db_conn.has_collection(collection_name): + raise AirflowException(f"Collection does not exist: {collection_name}") + + try: + collection = self.db_conn.collection(collection_name) + collection.update_many(documents, silent=True) + except DocumentUpdateError as e: + self.log.error("Failed to update documents: %s", str(e)) + raise + + def replace_documents(self, collection_name, documents): + if not self.db_conn.has_collection(collection_name): + raise AirflowException(f"Collection does not exist: {collection_name}") + + try: + collection = self.db_conn.collection(collection_name) + collection.replace_many(documents, silent=True) + except DocumentReplaceError as e: + self.log.error("Failed to replace documents: %s", str(e)) + raise + + def delete_documents(self, collection_name, documents): + if not 
self.db_conn.has_collection(collection_name): + raise AirflowException(f"Collection does not exist: {collection_name}") + + try: + collection = self.db_conn.collection(collection_name) + collection.delete_many(documents, silent=True) + except DocumentDeleteError as e: + self.log.error("Failed to delete documents: %s", str(e)) + raise + @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: return { diff --git a/providers/src/airflow/providers/arangodb/operators/arangodb.py b/providers/src/airflow/providers/arangodb/operators/arangodb.py index ac6bfe67612dc..cb8257495e431 100644 --- a/providers/src/airflow/providers/arangodb/operators/arangodb.py +++ b/providers/src/airflow/providers/arangodb/operators/arangodb.py @@ -18,8 +18,9 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Callable +from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.arangodb.hooks.arangodb import ArangoDBHook @@ -64,3 +65,88 @@ def execute(self, context: Context): result = hook.query(self.query) if self.result_processor: self.result_processor(result) + + +class ArangoDBCollectionOperator(BaseOperator): + """ + Executes collection operations in a ArangoDB database. + + :param arangodb_conn_id: Connection ID for ArangoDB, defaults to "arangodb_default". + :param collection_name: The name of the collection to be operated on. + :param documents_to_insert: A list of python dictionaries to insert into the collection. + :param documents_to_update: A list of python dictionaries to update in the collection. + :param documents_to_replace: A list of python dictionaries to replace in the collection. + :param documents_to_delete: A list of python dictionaries to delete from the collection. + :param delete_collection: If True, the specified collection will be deleted. 
+ """ + + def __init__( + self, + *, + arangodb_conn_id: str = "arangodb_default", + collection_name: str, + documents_to_insert: list[dict[str, Any]] | None = None, + documents_to_update: list[dict[str, Any]] | None = None, + documents_to_replace: list[dict[str, Any]] | None = None, + documents_to_delete: list[dict[str, Any]] | None = None, + delete_collection: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.arangodb_conn_id = arangodb_conn_id + self.collection_name = collection_name + self.documents_to_insert = documents_to_insert or [] + self.documents_to_update = documents_to_update or [] + self.documents_to_replace = documents_to_replace or [] + self.documents_to_delete = documents_to_delete or [] + self.delete_collection = delete_collection + + def execute(self, context: Context): + hook = ArangoDBHook(arangodb_conn_id=self.arangodb_conn_id) + + if not any( + [ + self.documents_to_insert, + self.documents_to_update, + self.documents_to_replace, + self.documents_to_delete, + self.delete_collection, + ] + ): + raise AirflowException("At least one operation must be specified.") + + if self.documents_to_insert: + self.log.info( + "Inserting %d documents into collection '%s'.", + len(self.documents_to_insert), + self.collection_name, + ) + hook.insert_documents(self.collection_name, self.documents_to_insert) + + if self.documents_to_update: + self.log.info( + "Updating %d documents in collection '%s'.", + len(self.documents_to_update), + self.collection_name, + ) + hook.update_documents(self.collection_name, self.documents_to_update) + + if self.documents_to_replace: + self.log.info( + "Replacing %d documents in collection '%s'.", + len(self.documents_to_replace), + self.collection_name, + ) + hook.replace_documents(self.collection_name, self.documents_to_replace) + + if self.documents_to_delete: + self.log.info( + "Deleting %d documents from collection '%s'.", + len(self.documents_to_delete), + self.collection_name, + ) + 
hook.delete_documents(self.collection_name, self.documents_to_delete) + + if self.delete_collection: + self.log.info("Deleting collection '%s'.", self.collection_name) + hook.delete_collection(self.collection_name) diff --git a/providers/tests/arangodb/hooks/test_arangodb.py b/providers/tests/arangodb/hooks/test_arangodb.py index 4b663d86dc3da..6a8d2916d95fa 100644 --- a/providers/tests/arangodb/hooks/test_arangodb.py +++ b/providers/tests/arangodb/hooks/test_arangodb.py @@ -19,6 +19,7 @@ from unittest.mock import Mock, patch import pytest +from arango.cursor import Cursor from airflow.models import Connection from airflow.providers.arangodb.hooks.arangodb import ArangoDBHook @@ -67,6 +68,9 @@ def test_get_conn(self, arango_mock): def test_query(self, arango_mock): arangodb_hook = ArangoDBHook() with patch.object(arangodb_hook, "db_conn"): + mock_cursor = Mock(spec=Cursor) + arangodb_hook.db_conn.aql.execute.return_value = mock_cursor + arangodb_query = "FOR doc IN students RETURN doc" arangodb_hook.query(arangodb_query) diff --git a/providers/tests/arangodb/operators/test_arangodb.py b/providers/tests/arangodb/operators/test_arangodb.py index db884411a3f8d..9963f007a384a 100644 --- a/providers/tests/arangodb/operators/test_arangodb.py +++ b/providers/tests/arangodb/operators/test_arangodb.py @@ -18,7 +18,7 @@ from unittest import mock -from airflow.providers.arangodb.operators.arangodb import AQLOperator +from airflow.providers.arangodb.operators.arangodb import AQLOperator, ArangoDBCollectionOperator class TestAQLOperator: @@ -29,3 +29,21 @@ def test_arangodb_operator_test(self, mock_hook): op.execute(mock.MagicMock()) mock_hook.assert_called_once_with(arangodb_conn_id="arangodb_default") mock_hook.return_value.query.assert_called_once_with(arangodb_query) + + +class TestArangoDBCollectionOperator: + @mock.patch("airflow.providers.arangodb.operators.arangodb.ArangoDBHook") + def test_insert_documents(self, mock_hook): + documents_to_insert = [{"_key": "lola", 
"first": "Lola", "last": "Martin"}] + op = ArangoDBCollectionOperator( + task_id="insert_task", + collection_name="students", + documents_to_insert=documents_to_insert, + documents_to_update=None, + documents_to_replace=None, + documents_to_delete=None, + delete_collection=False, + ) + op.execute(mock.MagicMock()) + mock_hook.assert_called_once_with(arangodb_conn_id="arangodb_default") + mock_hook.return_value.insert_documents.assert_called_once_with("students", documents_to_insert)