
Add OperationsPreprocessor to queueing #1485

Merged: 10 commits, Oct 6, 2023
2 changes: 1 addition & 1 deletion src/neptune/internal/backends/hosted_neptune_backend.py
@@ -457,7 +457,7 @@ def execute_operations(
dropped_count = operations_batch.dropped_operations_count

operations_preprocessor = OperationsPreprocessor()
-        operations_preprocessor.process(operations_batch.operations)
+        operations_preprocessor.process_batch(operations_batch.operations)

preprocessed_operations = operations_preprocessor.get_operations()
errors.extend(preprocessed_operations.errors)
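(Usage sketch, not part of the diff: the renamed entry point keeps its batch semantics, while the single-op `process` now reports whether an operation was accepted. This assumes the PR's revision of neptune-client is importable; the `LogFloats.ValueType(value, step, ts)` constructor is taken from `neptune/internal/operation.py`.)

```python
import time

from neptune.internal.backends.operations_preprocessor import OperationsPreprocessor
from neptune.internal.operation import LogFloats

now = time.time()
ops = [
    LogFloats(["metrics", "loss"], [LogFloats.ValueType(0.5, step=None, ts=now)]),
    LogFloats(["metrics", "loss"], [LogFloats.ValueType(0.4, step=None, ts=now)]),
]

preprocessor = OperationsPreprocessor()
preprocessor.process_batch(ops)  # stops early if an op must wait for a server round-trip

accumulated = preprocessor.get_operations()
print(preprocessor.processed_ops_count)   # 2: both ops were accepted
print(len(accumulated.all_operations()))  # 1: consecutive appends to one path are merged
```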
75 changes: 59 additions & 16 deletions src/neptune/internal/backends/operations_preprocessor.py
@@ -21,6 +21,7 @@
from typing import (
Callable,
List,
Type,
TypeVar,
)

@@ -44,6 +45,7 @@
DeleteFiles,
LogFloats,
LogImages,
LogOperation,
LogStrings,
Operation,
RemoveStrings,
@@ -70,24 +72,41 @@ class AccumulatedOperations:

errors: List[MetadataInconsistency] = dataclasses.field(default_factory=list)

def all_operations(self) -> List[Operation]:
return self.upload_operations + self.artifact_operations + self.other_operations


class OperationsPreprocessor:
def __init__(self):
self._accumulators: typing.Dict[str, "_OperationsAccumulator"] = dict()
        self.processed_ops_count = 0
+        self.final_ops_count = 0
+        self.final_append_count = 0

-    def process(self, operations: List[Operation]):
+    def process(self, operation: Operation) -> bool:
+        """Adds a single operation to its processed list.
+
+        Returns `False` iff the new operation cannot be queued until one of the already-enqueued
+        operations has been synchronized with the server first.
+        """
+        try:
+            self._process_op(operation)
+            self.processed_ops_count += 1
+            return True
+        except RequiresPreviousCompleted:
+            return False
+
+    def process_batch(self, operations: List[Operation]) -> None:
        for op in operations:
-            try:
-                self._process_op(op)
-                self.processed_ops_count += 1
-            except RequiresPreviousCompleted:
+            if not self.process(op):
                return

def _process_op(self, op: Operation) -> "_OperationsAccumulator":
path_str = path_to_str(op.path)
target_acc = self._accumulators.setdefault(path_str, _OperationsAccumulator(op.path))
old_ops_count, old_append_count = target_acc.get_op_count(), target_acc.get_append_count()
target_acc.visit(op)
self.final_ops_count += target_acc.get_op_count() - old_ops_count
self.final_append_count += target_acc.get_append_count() - old_append_count
return target_acc

@staticmethod
@@ -143,13 +162,21 @@ def __init__(self, path: List[str]):
self._modify_ops = []
self._config_ops = []
self._errors = []
self._ops_count = 0
self._append_count = 0

def get_operations(self) -> List[Operation]:
return self._delete_ops + self._modify_ops + self._config_ops

def get_errors(self) -> List[MetadataInconsistency]:
return self._errors

def get_op_count(self) -> int:
return self._ops_count

def get_append_count(self) -> int:
return self._append_count

def _check_prerequisites(self, op: Operation):
if (OperationsPreprocessor.is_file_op(op) or OperationsPreprocessor.is_artifact_op(op)) and len(
self._delete_ops
@@ -179,7 +206,9 @@ def _process_modify_op(
else:
self._check_prerequisites(op)
self._type = expected_type
old_op_count = len(self._modify_ops)
self._modify_ops = modifier(self._modify_ops, op)
self._ops_count += len(self._modify_ops) - old_op_count

def _process_config_op(self, expected_type: _DataType, op: Operation) -> None:

@@ -199,7 +228,9 @@ def _process_config_op(self, expected_type: _DataType, op: Operation) -> None:
else:
self._check_prerequisites(op)
self._type = expected_type
old_op_count = len(self._config_ops)
self._config_ops = [op]
self._ops_count += len(self._config_ops) - old_op_count

def visit_assign_float(self, op: AssignFloat) -> None:
self._process_modify_op(_DataType.FLOAT, op, self._assign_modifier())
@@ -295,6 +326,8 @@ def visit_delete_attribute(self, op: DeleteAttribute) -> None:
self._modify_ops = []
self._config_ops = []
self._type = None
self._ops_count = len(self._delete_ops)
self._append_count = 0
else:
# This case is tricky. There was no delete operation, but some modifications were performed.
# We do not know if this attribute exists on the server side and we do not want a delete op to fail.
@@ -303,6 +336,8 @@
self._modify_ops = []
self._config_ops = []
self._type = None
self._ops_count = len(self._delete_ops)
self._append_count = 0
else:
if self._delete_ops:
# Do nothing if there already is a delete operation
@@ -312,6 +347,7 @@ def visit_delete_attribute(self, op: DeleteAttribute) -> None:
# If value has not been set locally yet and no delete operation was performed,
# simply perform single delete operation.
self._delete_ops.append(op)
self._ops_count = len(self._delete_ops)

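(The tricky case called out in the comment above can be exercised directly: an assignment followed by a delete collapses into a single delete, kept so that a possible server-side attribute is removed, and a later assignment is queued after it. A sketch against this PR's revision; the `AssignFloat(path, value)` and `DeleteAttribute(path)` signatures are assumed from `neptune.internal.operation`.)

```python
from neptune.internal.backends.operations_preprocessor import OperationsPreprocessor
from neptune.internal.operation import AssignFloat, DeleteAttribute

pre = OperationsPreprocessor()
pre.process_batch([
    AssignFloat(["params", "lr"], 0.01),   # local assign is dropped...
    DeleteAttribute(["params", "lr"]),     # ...but the delete is kept, to be safe
    AssignFloat(["params", "lr"], 0.02),   # re-created after the delete
])
print([type(op).__name__ for op in pre.get_operations().all_operations()])
# expected: ['DeleteAttribute', 'AssignFloat']
```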
@staticmethod
def _artifact_log_modifier(
@@ -340,23 +376,30 @@ def visit_copy_attribute(self, op: CopyAttribute) -> None:
def _assign_modifier():
return lambda ops, new_op: [new_op]

-    @staticmethod
-    def _clear_modifier():
-        return lambda ops, new_op: [new_op]
+    def _clear_modifier(self):
+        def modifier(ops: List[Operation], new_op: Operation):
+            for op in ops:
+                if isinstance(op, LogOperation):
+                    self._append_count -= op.value_count()
+            return [new_op]
+
+        return modifier

-    @staticmethod
-    def _log_modifier(log_op_class: type, clear_op_class: type, log_combine: Callable[[T, T], T]):
-        def modifier(ops, new_op):
+    def _log_modifier(self, log_op_class: Type[LogOperation], clear_op_class: type, log_combine: Callable[[T, T], T]):
+        def modifier(ops: List[Operation], new_op: Operation):
            if len(ops) == 0:
-                return [new_op]
+                res = [new_op]
            elif len(ops) == 1 and isinstance(ops[0], log_op_class):
-                return [log_combine(ops[0], new_op)]
+                res = [log_combine(ops[0], new_op)]
            elif len(ops) == 1 and isinstance(ops[0], clear_op_class):
-                return [ops[0], new_op]
+                res = [ops[0], new_op]
            elif len(ops) == 2:
-                return [ops[0], log_combine(ops[1], new_op)]
+                res = [ops[0], log_combine(ops[1], new_op)]
            else:
                raise InternalClientError("Preprocessing operations failed: len(ops) == {}".format(len(ops)))
+            if isinstance(new_op, log_op_class):  # Check just so that static typing doesn't complain
+                self._append_count += new_op.value_count()
+            return res

        return modifier

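(The modifier keeps each path's pending list in one of four shapes: `[log]`, `[combined log]`, `[clear]`, or `[clear, log]`, so a batch never holds more than one clear plus one merged append per path. Below is a standalone rendition of that state machine, purely illustrative: plain lists stand in for log operations and the string "CLEAR" stands in for a clear operation.)

```python
from typing import Callable, List

def log_modifier(is_log, is_clear, combine: Callable):
    # Mirrors the four cases above: empty, single log, single clear, clear + log.
    def modifier(ops: List, new_op):
        if len(ops) == 0:
            return [new_op]
        elif len(ops) == 1 and is_log(ops[0]):
            return [combine(ops[0], new_op)]
        elif len(ops) == 1 and is_clear(ops[0]):
            return [ops[0], new_op]
        elif len(ops) == 2:
            return [ops[0], combine(ops[1], new_op)]
        raise RuntimeError(f"unexpected len(ops) == {len(ops)}")

    return modifier

mod = log_modifier(
    is_log=lambda o: isinstance(o, list),
    is_clear=lambda o: o == "CLEAR",
    combine=lambda a, b: a + b,
)

ops: List = []
for op in ([1.0], [2.0], "CLEAR", [3.0], [4.0]):
    ops = ["CLEAR"] if op == "CLEAR" else mod(ops, op)
print(ops)  # ['CLEAR', [3.0, 4.0]] -- one clear plus one merged append survive
```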
13 changes: 12 additions & 1 deletion src/neptune/internal/operation.py
@@ -292,7 +292,9 @@ def from_dict(data: dict) -> "UploadFileSet":


class LogOperation(Operation, abc.ABC):
-    pass
+    @abc.abstractmethod
+    def value_count(self) -> int:
+        pass


@dataclass
@@ -332,6 +334,9 @@ def from_dict(data: dict) -> "LogFloats":
[LogFloats.ValueType.from_dict(value) for value in data["values"]],
)

def value_count(self) -> int:
return len(self.values)


@dataclass
class LogStrings(LogOperation):
@@ -355,6 +360,9 @@ def from_dict(data: dict) -> "LogStrings":
[LogStrings.ValueType.from_dict(value) for value in data["values"]],
)

def value_count(self) -> int:
return len(self.values)


@dataclass
class ImageValue:
@@ -400,6 +408,9 @@ def from_dict(data: dict) -> "LogImages":
[LogImages.ValueType.from_dict(value, ImageValue.deserializer) for value in data["values"]],
)

def value_count(self) -> int:
return len(self.values)


@dataclass
class ClearFloatLog(Operation):
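(`value_count` is what lets the consumer cap the number of appended points per batch rather than just the number of operations. A quick check, assuming the `LogSeriesValue(value, step, ts)` shape used elsewhere in this module:)

```python
import time

from neptune.internal.operation import LogFloats, LogStrings

now = time.time()
floats = LogFloats(
    ["m", "loss"],
    [LogFloats.ValueType(v, step=None, ts=now) for v in (0.9, 0.7, 0.5)],
)
strings = LogStrings(["m", "status"], [LogStrings.ValueType("ok", step=None, ts=now)])

print(floats.value_count())   # 3 -- three points carried by one LogFloats op
print(strings.value_count())  # 1
```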
src/neptune/internal/operation_processors/async_operation_processor.py
@@ -26,23 +26,32 @@
)
from typing import (
Callable,
ClassVar,
List,
Optional,
Tuple,
)

from neptune.constants import ASYNC_DIRECTORY
from neptune.envs import NEPTUNE_SYNC_AFTER_STOP_TIMEOUT
from neptune.exceptions import NeptuneSynchronizationAlreadyStoppedException
from neptune.internal.backends.neptune_backend import NeptuneBackend
from neptune.internal.backends.operations_preprocessor import OperationsPreprocessor
from neptune.internal.container_type import ContainerType
-from neptune.internal.disk_queue import DiskQueue
+from neptune.internal.disk_queue import (
+    DiskQueue,
+    QueueElement,
+)
from neptune.internal.id_formats import UniqueId
from neptune.internal.init.parameters import (
ASYNC_LAG_THRESHOLD,
ASYNC_NO_PROGRESS_THRESHOLD,
DEFAULT_STOP_TIMEOUT,
)
-from neptune.internal.operation import Operation
+from neptune.internal.operation import (
+    CopyAttribute,
+    Operation,
+)
from neptune.internal.operation_processors.operation_processor import OperationProcessor
from neptune.internal.operation_processors.operation_storage import (
OperationStorage,
@@ -255,6 +264,10 @@ def close(self):
self._queue.close()

class ConsumerThread(Daemon):
MAX_OPERATIONS_IN_BATCH: ClassVar[int] = 1000
MAX_APPENDS_IN_BATCH: ClassVar[int] = 100000
MAX_BATCH_SIZE_BYTES: ClassVar[int] = 100 * 1024 * 1024

def __init__(
self,
processor: "AsyncOperationProcessor",
@@ -266,6 +279,7 @@ def __init__(
self._batch_size = batch_size
self._last_flush = 0
self._no_progress_exceeded = False
self._last_disk_record: Optional[QueueElement[Operation]] = None

def run(self):
try:
@@ -282,10 +296,42 @@ def work(self) -> None:
self._processor._queue.flush()

while True:
-                batch = self._processor._queue.get_batch(self._batch_size)
+                batch = self.collect_batch()
if not batch:
return
-                self.process_batch([element.obj for element in batch], batch[-1].ver)
+                operations, version = batch
+                self.process_batch(operations, version)

def collect_batch(self) -> Optional[Tuple[List[Operation], int]]:
preprocessor = OperationsPreprocessor()
version: Optional[int] = None
total_bytes = 0
copy_ops: List[CopyAttribute] = []
while (
preprocessor.final_ops_count < self.MAX_OPERATIONS_IN_BATCH
and preprocessor.final_append_count < self.MAX_APPENDS_IN_BATCH
and total_bytes < self.MAX_BATCH_SIZE_BYTES
):
record: Optional[QueueElement[Operation]] = self._last_disk_record or self._processor._queue.get()
self._last_disk_record = None
if not record:
break
if isinstance(record.obj, CopyAttribute):
# A CopyAttribute can only appear at the start of a batch.
if copy_ops or preprocessor.final_ops_count:
self._last_disk_record = record
break
else:
version = record.ver
copy_ops.append(record.obj)
total_bytes += record.size
elif preprocessor.process(record.obj):
version = record.ver
total_bytes += record.size
else:
self._last_disk_record = record
break
return (copy_ops + preprocessor.get_operations().all_operations(), version) if version is not None else None

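(`collect_batch` effectively needs a one-element pushback: an operation that does not fit the current batch, or a `CopyAttribute` arriving mid-batch, is parked in `_last_disk_record` and opens the next batch. Below is a generic, Neptune-independent sketch of that pattern, with integer weights standing in for operation sizes; the real consumer additionally caps operation and append counts.)

```python
from collections import deque
from typing import Deque, List

MAX_BATCH_WEIGHT = 10  # stand-in for MAX_BATCH_SIZE_BYTES and friends

def collect_batch(queue: Deque[int], pushback: List[int]) -> List[int]:
    # Drain items until the weight budget is hit; park the first item that
    # does not fit so the next call starts with it (cf. _last_disk_record).
    batch: List[int] = []
    weight = 0
    while weight < MAX_BATCH_WEIGHT:
        item = pushback.pop() if pushback else (queue.popleft() if queue else None)
        if item is None:
            break
        if batch and weight + item > MAX_BATCH_WEIGHT:
            pushback.append(item)  # does not fit: reread it in the next batch
            break
        batch.append(item)
        weight += item
    return batch

queue, held = deque([4, 4, 4, 9, 2]), []
while (batch := collect_batch(queue, held)):
    print(batch)  # [4, 4] then [4] then [9] then [2]
```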
def _check_no_progress(self):
if not self._no_progress_exceeded: