Skip to content

Commit

Permalink
Added the ArangoDBCollectionOperator that executes collection operati…
Browse files Browse the repository at this point in the history
…ons in a ArangoDB database (#44676)

* - Added the ArangoDBCollectionOperator that executes collection operations in a ArangoDB database
- Insert, replace, update and delete documents functionality included

* Fixed a failing test in hooks
  • Loading branch information
harjeevanmaan authored Dec 6, 2024
1 parent f326e47 commit 258ef9d
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 5 deletions.
65 changes: 62 additions & 3 deletions providers/src/airflow/providers/arangodb/hooks/arangodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@
from typing import TYPE_CHECKING, Any

from arango import AQLQueryExecuteError, ArangoClient as ArangoDBClient
from arango.cursor import Cursor
from arango.exceptions import (
DocumentDeleteError,
DocumentInsertError,
DocumentReplaceError,
DocumentUpdateError,
)

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook

if TYPE_CHECKING:
from arango.cursor import Cursor
from arango.database import StandardDatabase

from airflow.models import Connection
Expand Down Expand Up @@ -101,8 +107,8 @@ def query(self, query, **kwargs) -> Cursor:
try:
if self.db_conn:
result = self.db_conn.aql.execute(query, **kwargs)
if TYPE_CHECKING:
assert isinstance(result, Cursor)
if not isinstance(result, Cursor):
raise AirflowException("Failed to execute AQLQuery, expected result to be of type Cursor")
return result
else:
raise AirflowException(
Expand All @@ -113,12 +119,21 @@ def query(self, query, **kwargs) -> Cursor:

def create_collection(self, name):
if not self.db_conn.has_collection(name):
self.log.info("Collection '%s' does not exist. Creating a new collection.", name)
self.db_conn.create_collection(name)
return True
else:
self.log.info("Collection already exists: %s", name)
return False

def delete_collection(self, name):
if self.db_conn.has_collection(name):
self.db_conn.delete_collection(name)
return True
else:
self.log.info("Collection does not exist: %s", name)
return False

def create_database(self, name):
if not self.db_conn.has_database(name):
self.db_conn.create_database(name)
Expand All @@ -135,6 +150,50 @@ def create_graph(self, name):
self.log.info("Graph already exists: %s", name)
return False

def insert_documents(self, collection_name, documents):
if not self.db_conn.has_collection(collection_name):
self.create_collection(collection_name)

try:
collection = self.db_conn.collection(collection_name)
collection.insert_many(documents, silent=True)
except DocumentInsertError as e:
self.log.error("Failed to insert documents: %s", str(e))
raise

def update_documents(self, collection_name, documents):
if not self.db_conn.has_collection(collection_name):
raise AirflowException(f"Collection does not exist: {collection_name}")

try:
collection = self.db_conn.collection(collection_name)
collection.update_many(documents, silent=True)
except DocumentUpdateError as e:
self.log.error("Failed to update documents: %s", str(e))
raise

def replace_documents(self, collection_name, documents):
if not self.db_conn.has_collection(collection_name):
raise AirflowException(f"Collection does not exist: {collection_name}")

try:
collection = self.db_conn.collection(collection_name)
collection.replace_many(documents, silent=True)
except DocumentReplaceError as e:
self.log.error("Failed to replace documents: %s", str(e))
raise

def delete_documents(self, collection_name, documents):
if not self.db_conn.has_collection(collection_name):
raise AirflowException(f"Collection does not exist: {collection_name}")

try:
collection = self.db_conn.collection(collection_name)
collection.delete_many(documents, silent=True)
except DocumentDeleteError as e:
self.log.error("Failed to delete documents: %s", str(e))
raise

@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
return {
Expand Down
88 changes: 87 additions & 1 deletion providers/src/airflow/providers/arangodb/operators/arangodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.arangodb.hooks.arangodb import ArangoDBHook

Expand Down Expand Up @@ -64,3 +65,88 @@ def execute(self, context: Context):
result = hook.query(self.query)
if self.result_processor:
self.result_processor(result)


class ArangoDBCollectionOperator(BaseOperator):
"""
Executes collection operations in a ArangoDB database.
:param arangodb_conn_id: Connection ID for ArangoDB, defaults to "arangodb_default".
:param collection_name: The name of the collection to be operated on.
:param documents_to_insert: A list of python dictionaries to insert into the collection.
:param documents_to_update: A list of python dictionaries to update in the collection.
:param documents_to_replace: A list of python dictionaries to replace in the collection.
:param documents_to_delete: A list of python dictionaries to delete from the collection.
:param delete_collection: If True, the specified collection will be deleted.
"""

def __init__(
self,
*,
arangodb_conn_id: str = "arangodb_default",
collection_name: str,
documents_to_insert: list[dict[str, Any]] | None = None,
documents_to_update: list[dict[str, Any]] | None = None,
documents_to_replace: list[dict[str, Any]] | None = None,
documents_to_delete: list[dict[str, Any]] | None = None,
delete_collection: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.arangodb_conn_id = arangodb_conn_id
self.collection_name = collection_name
self.documents_to_insert = documents_to_insert or []
self.documents_to_update = documents_to_update or []
self.documents_to_replace = documents_to_replace or []
self.documents_to_delete = documents_to_delete or []
self.delete_collection = delete_collection

def execute(self, context: Context):
hook = ArangoDBHook(arangodb_conn_id=self.arangodb_conn_id)

if not any(
[
self.documents_to_insert,
self.documents_to_update,
self.documents_to_replace,
self.documents_to_delete,
self.delete_collection,
]
):
raise AirflowException("At least one operation must be specified.")

if self.documents_to_insert:
self.log.info(
"Inserting %d documents into collection '%s'.",
len(self.documents_to_insert),
self.collection_name,
)
hook.insert_documents(self.collection_name, self.documents_to_insert)

if self.documents_to_update:
self.log.info(
"Updating %d documents in collection '%s'.",
len(self.documents_to_update),
self.collection_name,
)
hook.update_documents(self.collection_name, self.documents_to_update)

if self.documents_to_replace:
self.log.info(
"Replacing %d documents in collection '%s'.",
len(self.documents_to_replace),
self.collection_name,
)
hook.replace_documents(self.collection_name, self.documents_to_replace)

if self.documents_to_delete:
self.log.info(
"Deleting %d documents from collection '%s'.",
len(self.documents_to_delete),
self.collection_name,
)
hook.delete_documents(self.collection_name, self.documents_to_delete)

if self.delete_collection:
self.log.info("Deleting collection '%s'.", self.collection_name)
hook.delete_collection(self.collection_name)
4 changes: 4 additions & 0 deletions providers/tests/arangodb/hooks/test_arangodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from unittest.mock import Mock, patch

import pytest
from arango.cursor import Cursor

from airflow.models import Connection
from airflow.providers.arangodb.hooks.arangodb import ArangoDBHook
Expand Down Expand Up @@ -67,6 +68,9 @@ def test_get_conn(self, arango_mock):
def test_query(self, arango_mock):
arangodb_hook = ArangoDBHook()
with patch.object(arangodb_hook, "db_conn"):
mock_cursor = Mock(spec=Cursor)
arangodb_hook.db_conn.aql.execute.return_value = mock_cursor

arangodb_query = "FOR doc IN students RETURN doc"
arangodb_hook.query(arangodb_query)

Expand Down
20 changes: 19 additions & 1 deletion providers/tests/arangodb/operators/test_arangodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from unittest import mock

from airflow.providers.arangodb.operators.arangodb import AQLOperator
from airflow.providers.arangodb.operators.arangodb import AQLOperator, ArangoDBCollectionOperator


class TestAQLOperator:
Expand All @@ -29,3 +29,21 @@ def test_arangodb_operator_test(self, mock_hook):
op.execute(mock.MagicMock())
mock_hook.assert_called_once_with(arangodb_conn_id="arangodb_default")
mock_hook.return_value.query.assert_called_once_with(arangodb_query)


class TestArangoDBCollectionOperator:
@mock.patch("airflow.providers.arangodb.operators.arangodb.ArangoDBHook")
def test_insert_documents(self, mock_hook):
documents_to_insert = [{"_key": "lola", "first": "Lola", "last": "Martin"}]
op = ArangoDBCollectionOperator(
task_id="insert_task",
collection_name="students",
documents_to_insert=documents_to_insert,
documents_to_update=None,
documents_to_replace=None,
documents_to_delete=None,
delete_collection=False,
)
op.execute(mock.MagicMock())
mock_hook.assert_called_once_with(arangodb_conn_id="arangodb_default")
mock_hook.return_value.insert_documents.assert_called_once_with("students", documents_to_insert)

0 comments on commit 258ef9d

Please sign in to comment.