1 parent 8276c27 · commit 9fc25c8 · 4 changed files with 455 additions and 0 deletions.

from __future__ import annotations

import json
from concurrent.futures import Future
from contextlib import contextmanager
from queue import Queue
from threading import Lock, Semaphore
from typing import Optional

import requests

from ravendb.exceptions.raven_exceptions import RavenException
from ravendb.http.server_node import ServerNode
from ravendb.http.raven_command import RavenCommand
from ravendb.documents.operations.misc import GetOperationStateOperation
from ravendb.documents.session.entity_to_json import EntityToJson
from ravendb.documents.session.document_info import DocumentInfo
from ravendb import DocumentStore, MetadataAsDictionary, constants
from ravendb.documents.commands.batches import CommandType
from ravendb.documents.commands.bulkinsert import GetNextOperationIdCommand, KillOperationCommand
from ravendb.exceptions.documents.bulkinsert import BulkInsertAbortedException
from ravendb.documents.identity.hilo import GenerateEntityIdOnTheClient
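
# BulkInsertOperation streams documents to the server's bulk_insert endpoint as a
# single JSON array over one long-lived HTTP request: documents are serialized into an
# in-memory buffer, flushed to a background writer once the buffer grows past
# _max_size_in_buffer (1 MB), and the whole operation can be aborted server-side via
# its operation id.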
class BulkInsertOperation:
    class _BufferExposer:
        def __init__(self):
            self._data = {}
            self._done = Future()
            self.output_stream = Future()
            self._yield_buffer_semaphore = Semaphore(0)
            self._flushed_buffers_queue = Queue()

        def enqueue_buffer(self, buffer: bytearray):
            self._flushed_buffers_queue.put(buffer)
            self._yield_buffer_semaphore.release()  # signal that a buffer is ready to be sent

        def send_data(self):  # todo: prepare a send loop, not just a single one
            self._yield_buffer_semaphore.acquire()  # block until a buffer has been enqueued
            yield self._flushed_buffers_queue.get()

        def is_done(self) -> bool:
            return self._done.done()

        def done(self):
            self._done.set_result(None)

        def error_on_processing_request(self, exception: Exception):
            self._done.set_exception(exception)

        def error_on_request_start(self, exception: Exception):
            self.output_stream.set_exception(exception)
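
    # _BulkInsertCommand opens the bulk_insert HTTP request itself; the request body is
    # the generator returned by _BufferExposer.send_data(), so the flushed buffers are
    # streamed to the server as they become available.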
    class _BulkInsertCommand(RavenCommand[requests.Response]):
        def __init__(self, key: int, buffer_exposer: BulkInsertOperation._BufferExposer, node_tag: str):
            super().__init__(requests.Response)
            self._buffer_exposer = buffer_exposer
            self._key = key
            self._selected_node_tag = node_tag
            self.use_compression = False

        def create_request(self, node: ServerNode) -> requests.Request:
            return requests.Request(
                "POST",
                f"{node.url}/databases/{node.database}/bulk_insert?id={self._key}",
                data=self._buffer_exposer.send_data(),
            )

        def set_response(self, response: Optional[str], from_cache: bool) -> None:
            raise NotImplementedError("Not Implemented")

        def is_read_request(self) -> bool:
            return False

        def send(self, session: requests.Session, request: requests.Request) -> requests.Response:
            try:
                return super().send(session, request)
            except Exception as e:
                self._buffer_exposer.error_on_request_start(e)
                raise
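
    # Writes accumulate in _current_data_buffer; once it exceeds _max_size_in_buffer
    # the buffer is handed to the exposer on the thread pool (_async_write) while a
    # fresh buffer keeps collecting the next documents.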
    def __init__(self, database: str = None, store: DocumentStore = None):
        self.use_compression = False

        self._bulk_insert_execute_task: Optional[Future] = None
        self._first = True
        self._in_progress_command: Optional[CommandType] = None
        self._operation_id = -1
        self._node_tag = None
        self._concurrent_check = 0
        self._concurrent_check_lock = Lock()

        self._thread_pool_executor = store.thread_pool_executor
        self._conventions = store.conventions
        if not database or database.isspace():
            self._throw_no_database()
        self._request_executor = store.get_request_executor(database)

        self._async_write = Future()
        self._async_write.set_result(None)

        self._max_size_in_buffer = 1024 * 1024

        # grows as documents are serialized; bytearray(n) would pre-fill it with n zero bytes
        self._current_data_buffer = bytearray()

        # todo: self._time_series_batch_size = self._conventions.time_series_batch_size
        self._buffer_exposer = BulkInsertOperation._BufferExposer()

        self._generate_entity_id_on_the_client = GenerateEntityIdOnTheClient(
            self._request_executor.conventions,
            lambda entity: self._request_executor.conventions.generate_document_id(database, entity),
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._end_previous_command_if_needed()

        flush_ex = None

        if self._buffer_exposer.is_done():
            return

        if self._current_data_buffer:
            try:
                # close the JSON array that was opened in _ensure_execute_task
                self._current_data_buffer += bytearray("]", encoding="utf-8")
                self._async_write.result()

                buffer = self._current_data_buffer
                self._buffer_exposer.enqueue_buffer(buffer)
            except Exception as e:
                flush_ex = e

        self._buffer_exposer.done()

        if self._operation_id == -1:
            # closing without calling a single store
            return

        if self._bulk_insert_execute_task is not None:
            try:
                self._bulk_insert_execute_task.result()
            except Exception as e:
                self._throw_bulk_insert_aborted(e, flush_ex)

    def _throw_bulk_insert_aborted(self, e: Exception, flush_ex: Optional[Exception]):
        error_from_server = None
        try:
            error_from_server = self._get_exception_from_operation()
        except Exception:
            pass  # server is probably down, will propagate the original exception

        if error_from_server is not None:
            raise error_from_server

        raise BulkInsertAbortedException("Failed to execute bulk insert", e or flush_ex)

    def _throw_no_database(self):
        raise RuntimeError(
            "Cannot start bulk insert operation without specifying a name of a database to operate on. "
            "The database name can be passed as an argument when the bulk insert is created, or "
            "a default database can be defined using the 'DocumentStore.database' property."
        )

    def _wait_for_id(self):
        if self._operation_id != -1:
            return

        bulk_insert_get_id_request = GetNextOperationIdCommand()
        self._request_executor.execute_command(bulk_insert_get_id_request)
        self._operation_id = bulk_insert_get_id_request.result
        self._node_tag = bulk_insert_get_id_request._node_tag

    def store_by_metadata(self, entity: object, metadata: Optional[MetadataAsDictionary] = None) -> str:
        key = (
            self._get_id(entity)
            if metadata is None or constants.Documents.Metadata.ID not in metadata
            else metadata[constants.Documents.Metadata.ID]
        )

        self.store(entity, key, metadata)
        return key
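
    # Each stored document is written to the stream as
    # {"Id":"<key>","Type":"PUT","Document":{...}}, and entries are comma-separated
    # inside the single JSON array sent to the bulk_insert endpoint.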
    def store(self, entity: object, key: str, metadata: Optional[MetadataAsDictionary] = None) -> None:
        # _concurrency_check releases the flag on exit, even if an error is raised below
        with self._concurrency_check():
            self._verify_valid_id(key)

            self._execute_before_store()
            if metadata is None:
                metadata = MetadataAsDictionary()

            if constants.Documents.Metadata.COLLECTION not in metadata:
                collection = self._request_executor.conventions.get_collection_name(entity)
                if collection is not None:
                    metadata[constants.Documents.Metadata.COLLECTION] = collection

            if constants.Documents.Metadata.RAVEN_PYTHON_TYPE not in metadata:
                python_type = self._request_executor.conventions.get_python_class_name(entity.__class__)
                if python_type is not None:
                    metadata[constants.Documents.Metadata.RAVEN_PYTHON_TYPE] = python_type

            self._end_previous_command_if_needed()

            try:
                if not self._first:
                    self._write_comma()

                self._first = False
                self._in_progress_command = CommandType.NONE
                self._current_data_buffer += bytearray('{"Id":"', encoding="utf-8")
                self._write_string(key)
                self._current_data_buffer += bytearray('","Type":"PUT","Document":', encoding="utf-8")

                self._flush_if_needed()

                document_info = DocumentInfo(metadata_instance=metadata)
                json_dict = EntityToJson.convert_entity_to_json_internal_static(
                    entity, self._conventions, document_info, True
                )

                self._current_data_buffer += bytearray(json.dumps(json_dict), encoding="utf-8")
                self._current_data_buffer += bytearray("}", encoding="utf-8")
            except Exception as e:
                self._handle_errors(key, e)

    def _handle_errors(self, document_id: str, e: Exception) -> None:
        error = self._get_exception_from_operation()
        if error is not None:
            raise error

        self._throw_on_unavailable_stream(document_id, e)

    @contextmanager
    def _concurrency_check(self):
        with self._concurrent_check_lock:
            if self._concurrent_check != 0:
                raise RuntimeError("Bulk Insert store methods cannot be executed concurrently.")
            self._concurrent_check = 1

        try:
            yield
        finally:
            with self._concurrent_check_lock:
                if self._concurrent_check == 1:
                    self._concurrent_check = 0
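
    # Buffers are flushed on a worker thread so serialization of the next documents can
    # continue while the previous chunk is handed to the HTTP stream.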
    def _flush_if_needed(self) -> None:
        if len(self._current_data_buffer) > self._max_size_in_buffer or self._async_write.done():
            self._async_write.result()

            buffer = self._current_data_buffer
            # start a fresh buffer; 'buffer' still references the bytes about to be enqueued,
            # so clearing the old bytearray in place would wipe them
            # todo: check whether reusing a preallocated buffer is cheaper than allocating a new one
            self._current_data_buffer = bytearray()

            def __async_write():
                self._buffer_exposer.enqueue_buffer(buffer)
                return None

            self._async_write = self._thread_pool_executor.submit(__async_write)

    def _end_previous_command_if_needed(self) -> None:
        if self._in_progress_command == CommandType.COUNTERS:
            pass  # todo: counters
        elif self._in_progress_command == CommandType.TIME_SERIES:
            pass  # todo: time series

    def _write_string(self, input_string: str) -> None:
        for i in range(len(input_string)):
            c = input_string[i]
            if '"' == c:
                if i == 0 or input_string[i - 1] != "\\":
                    self._current_data_buffer += bytearray("\\", encoding="utf-8")

            self._current_data_buffer += bytearray(c, encoding="utf-8")

    def _write_comma(self) -> None:
        self._current_data_buffer += bytearray(",", encoding="utf-8")

    def _execute_before_store(self) -> None:
        if self._bulk_insert_execute_task is None:
            self._wait_for_id()
            self._ensure_execute_task()

        # Future.exception() blocks until the task completes, so only inspect the task
        # once it has already finished
        if self._bulk_insert_execute_task.done() and self._bulk_insert_execute_task.exception() is not None:
            try:
                self._bulk_insert_execute_task.result()
            except Exception as e:
                self._throw_bulk_insert_aborted(e, None)

    @staticmethod
    def _verify_valid_id(key: str) -> None:
        if not key or key.isspace():
            raise ValueError("Document id must have a non-empty value")

        if key.endswith("|"):
            raise RuntimeError(f"Document ids cannot end with '|', but was called with {key}")
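
    # Asks the server for the state of this bulk-insert operation so that a server-side
    # failure can be surfaced as a BulkInsertAbortedException carrying the original error.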
    def _get_exception_from_operation(self) -> Optional[BulkInsertAbortedException]:
        state_request = GetOperationStateOperation.GetOperationStateCommand(self._operation_id, self._node_tag)
        self._request_executor.execute_command(state_request)

        if "Faulted" != state_request.result["Status"]:
            return None

        result = state_request.result["Result"]

        if result["$type"].startswith("Raven.Client.Documents.Operations.OperationExceptionResult"):
            return BulkInsertAbortedException(result["Error"])

        return None

    def _ensure_execute_task(self) -> None:
        try:
            bulk_command = BulkInsertOperation._BulkInsertCommand(
                self._operation_id, self._buffer_exposer, self._node_tag
            )
            bulk_command.use_compression = self.use_compression

            def __core_async():
                self._request_executor.execute_command(bulk_command)
                return None

            self._bulk_insert_execute_task = self._thread_pool_executor.submit(__core_async)
            self._current_data_buffer += bytearray("[", encoding="utf-8")

        except Exception as e:
            raise RavenException("Unable to open bulk insert stream", e)

    def _throw_on_unavailable_stream(self, key: str, inner_ex: Exception) -> None:
        self._buffer_exposer.error_on_processing_request(
            BulkInsertAbortedException(f"Write to stream failed at document with id {key}", inner_ex)
        )

    def abort(self) -> None:
        if self._operation_id == -1:
            return  # nothing was done, nothing to kill

        self._wait_for_id()

        try:
            self._request_executor.execute_command(KillOperationCommand(self._operation_id, self._node_tag))
        except RavenException as e:
            raise BulkInsertAbortedException(
                "Unable to kill this bulk insert operation, because it was not found on the server", e
            )

    def _get_id(self, entity: object) -> str:
        success, key = self._generate_entity_id_on_the_client.try_get_id_from_instance(entity)
        if success:
            return key

        key = self._generate_entity_id_on_the_client.generate_document_key_for_storage(entity)

        self._generate_entity_id_on_the_client.try_set_identity(entity, key)
        return key

    # todo: attachments_for
    # todo: time_series_for
    # todo: CountersBulkInsert
    # todo: CountersBulkInsertOperation
    # todo: TimeSeriesBulkInsertBase
    # todo: TimeSeriesBulkInsert
    # todo: TypedTimeSeriesBulkInsert
    # todo: AttachmentsBulkInsert
    # todo: AttachmentsBulkInsertOperation
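
# Illustrative usage sketch: the server URL, database name, and Company class below are
# assumptions made for the example, not part of this module; it relies only on the
# DocumentStore initialize()/close() lifecycle and the context-manager protocol above.
if __name__ == "__main__":

    class Company:
        def __init__(self, name: str):
            self.name = name

    store = DocumentStore(urls=["http://localhost:8080"], database="Example")
    store.initialize()
    try:
        with BulkInsertOperation(database="Example", store=store) as bulk_insert:
            for i in range(1000):
                # the document id is generated from the store conventions when none is supplied
                bulk_insert.store_by_metadata(Company(name=f"company-{i}"))
    finally:
        store.close()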