From 4b86d6f59f1d5b28b5540af58d94810ad5f3c8f5 Mon Sep 17 00:00:00 2001 From: jnasselle Date: Wed, 3 Jan 2024 15:49:34 -0300 Subject: [PATCH] Several changes on workflow engine/processor - Add some simple examples - Improve code modularization: dag and file management - Allow on-error feature - Fix dtt1-agent dependency: agent provision depends on manager provision - Improve logging mechanism --- poc-tests/modules/classes/schemaValidator.py | 1 - .../workflow_engine/examples/dtt1-agents.yaml | 25 +- .../multiple_linear_independant_flow.yaml | 69 ++++ ...e_linear_independant_flow_templetized.yaml | 55 +++ .../examples/single_linear_flow.yaml | 38 ++ poc-tests/modules/workflow_engine/task.py | 43 +-- .../workflow_engine/workflow_processor.py | 353 +++++++++--------- poc-tests/scripts/workflow_engine.py | 16 +- 8 files changed, 391 insertions(+), 209 deletions(-) create mode 100644 poc-tests/modules/workflow_engine/examples/multiple_linear_independant_flow.yaml create mode 100644 poc-tests/modules/workflow_engine/examples/multiple_linear_independant_flow_templetized.yaml create mode 100644 poc-tests/modules/workflow_engine/examples/single_linear_flow.yaml diff --git a/poc-tests/modules/classes/schemaValidator.py b/poc-tests/modules/classes/schemaValidator.py index d4f2475430..5ef60f2179 100644 --- a/poc-tests/modules/classes/schemaValidator.py +++ b/poc-tests/modules/classes/schemaValidator.py @@ -41,7 +41,6 @@ def preprocess_data(self): def validateSchema(self): try: jsonschema.validate(self.yamlData, self.schemaData) - print("YAML is valid!") except jsonschema.exceptions.ValidationError as e: print(f"Validation error: {e}") except Exception as e: diff --git a/poc-tests/modules/workflow_engine/examples/dtt1-agents.yaml b/poc-tests/modules/workflow_engine/examples/dtt1-agents.yaml index 5b6c874644..e8a827282d 100644 --- a/poc-tests/modules/workflow_engine/examples/dtt1-agents.yaml +++ b/poc-tests/modules/workflow_engine/examples/dtt1-agents.yaml @@ -35,78 +35,85 @@ 
variables: tasks: # Generic agent test task - - task: "run-agent-tests-{agent}" + - task: "test-agent-{agent}" description: "Run tests for the {agent} agent." do: this: process with: path: /bin/echo args: + - -n - "Running tests for {agent}" depends-on: - - "provision-{agent}" - - "provision-manager" + - "provision-agent-{agent}" foreach: - variable: agents-os as: agent # Unique manager provision task - - task: "provision-manager" + - task: "provision-manager-{manager-os}" description: "Provision the manager." do: this: process with: path: /bin/echo args: + - -n - "Running provision for manager" depends-on: - - "allocate-manager" + - "allocate-manager-{manager-os}" # Unique manager allocate task - - task: "allocate-manager" + - task: "allocate-manager-{manager-os}" description: "Allocate resources for the manager." do: this: process with: path: /bin/echo args: + - -n - "Running allocate for manager" cleanup: this: process with: path: /bin/echo args: + - -n - "Running cleanup for manager" # Generic agent provision task - - task: "provision-{agent}" + - task: "provision-agent-{agent}" description: "Provision resources for the {agent} agent." do: this: process with: path: /bin/echo args: + - -n - "Running provision for {agent}" depends-on: - - "allocate-{agent}" + - "allocate-agent-{agent}" + - "provision-manager-{manager-os}" foreach: - variable: agents-os as: agent # Generic agent allocate task - - task: "allocate-{agent}" + - task: "allocate-agent-{agent}" description: "Allocate resources for the {agent} agent." 
do: this: process with: path: /bin/echo args: + - -n - "Running allocate for {agent}" cleanup: this: process with: path: /bin/echo args: + - -n - "Running cleanup for allocate for {agent}" foreach: - variable: agents-os diff --git a/poc-tests/modules/workflow_engine/examples/multiple_linear_independant_flow.yaml b/poc-tests/modules/workflow_engine/examples/multiple_linear_independant_flow.yaml new file mode 100644 index 0000000000..7f0ae53560 --- /dev/null +++ b/poc-tests/modules/workflow_engine/examples/multiple_linear_independant_flow.yaml @@ -0,0 +1,69 @@ +# Copyright (C) 2015, Wazuh Inc. +# Created by Wazuh, Inc. . +# This program is a free software; you can redistribute it and/or modify it under the terms of GPLv2 +version: 0.1 +description: This is a basic example of two linear and independant flows. + +tasks: + - task: "A1" + description: "Task A1" + do: + this: process + with: + path: /bin/echo + args: + - -n + - "Running task A1" + depends-on: + - "B1" + - task: "B1" + description: "Task B1" + do: + this: process + with: + path: /bin/echo + args: + - -n + - "Running task B1" + depends-on: + - "C1" + - task: "C1" + description: "Task C1" + do: + this: process + with: + path: /bin/echo + args: + - -n + - "Running task C1" + - task: "A2" + description: "Task A2" + do: + this: process + with: + path: /bin/echo + args: + - -n + - "Running task A2" + depends-on: + - "B2" + - task: "B2" + description: "Task B2" + do: + this: process + with: + path: /bin/echo + args: + - -n + - "Running task B2" + depends-on: + - "C2" + - task: "C2" + description: "Task C2" + do: + this: process + with: + path: /bin/echo + args: + - -n + - "Running task C2" diff --git a/poc-tests/modules/workflow_engine/examples/multiple_linear_independant_flow_templetized.yaml b/poc-tests/modules/workflow_engine/examples/multiple_linear_independant_flow_templetized.yaml new file mode 100644 index 0000000000..b4c55c96d7 --- /dev/null +++ 
b/poc-tests/modules/workflow_engine/examples/multiple_linear_independant_flow_templetized.yaml @@ -0,0 +1,55 @@ +# Copyright (C) 2015, Wazuh Inc. +# Created by Wazuh, Inc. . +# This program is a free software; you can redistribute it and/or modify it under the terms of GPLv2 +version: 0.1 +description: This is a basic example of two linear and independant flows using templates. + +variables: + index: + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 +tasks: + - task: "A{i}" + description: "Task A{i}" + do: + this: process + with: + path: /bin/echo + args: + - -n + - "Running task A{i}" + depends-on: + - "B{i}" + foreach: + - variable: index + as: i + - task: "B{i}" + description: "Task B{i}" + do: + this: process + with: + path: /bin/echo + args: + - -n + - "Running task B{i}" + depends-on: + - "C{i}" + foreach: + - variable: index + as: i + - task: "C{i}" + description: "Task C{i}" + do: + this: process + with: + path: /bin/echo + args: + - -n + - "Running task C{i}" + foreach: + - variable: index + as: i diff --git a/poc-tests/modules/workflow_engine/examples/single_linear_flow.yaml b/poc-tests/modules/workflow_engine/examples/single_linear_flow.yaml new file mode 100644 index 0000000000..25c6965a00 --- /dev/null +++ b/poc-tests/modules/workflow_engine/examples/single_linear_flow.yaml @@ -0,0 +1,38 @@ +# Copyright (C) 2015, Wazuh Inc. +# Created by Wazuh, Inc. . +# This program is a free software; you can redistribute it and/or modify it under the terms of GPLv2 +version: 0.1 +description: This is a basic example of linear flow. 
+ +tasks: + - task: "A" + description: "Task A" + do: + this: process + with: + path: /bin/echo + args: + - -n + - "Running task A" + depends-on: + - "B" + - task: "B" + description: "Task B" + do: + this: process + with: + path: /bin/echo + args: + - -n + - "Running task B" + depends-on: + - "C" + - task: "C" + description: "Task C" + do: + this: process + with: + path: /bin/echo + args: + - -n + - "Running task C" diff --git a/poc-tests/modules/workflow_engine/task.py b/poc-tests/modules/workflow_engine/task.py index e1b8440649..a89f0fb431 100644 --- a/poc-tests/modules/workflow_engine/task.py +++ b/poc-tests/modules/workflow_engine/task.py @@ -8,6 +8,7 @@ import random import time +logger = (lambda: logging.getLogger())() class Task(ABC): """Abstract base class for tasks.""" @@ -21,7 +22,7 @@ def execute(self) -> None: class ProcessTask(Task): """Task for executing a process.""" - def __init__(self, task_name: str, task_parameters: dict, logger: logging.Logger): + def __init__(self, task_name: str, task_parameters: dict): """ Initialize ProcessTask. 
@@ -44,48 +45,44 @@ def format_key_value(task_arg): task_args = [str(task_arg) if isinstance(task_arg, str) else format_key_value(task_arg) for task_arg in self.task_parameters['args']] - try: - result = subprocess.run( - [self.task_parameters['path']] + task_args, - check=True, - capture_output=True, - text=True, - ) + result = subprocess.run( + [self.task_parameters['path']] + task_args, + check=True, + capture_output=True, + text=True, + ) - self.logger.info("Output:\n%s", result.stdout, extra={'tag': self.task_name}) - - if result.returncode != 0: - raise subprocess.CalledProcessError(returncode=result.returncode, cmd=result.args, output=result.stdout) - - except Exception as e: - self.logger.error("Task failed with error: %s", e, extra={'tag': self.task_name}) - # Handle the exception or re-raise if necessary - raise + if result.returncode != 0: + raise subprocess.CalledProcessError(returncode=result.returncode, cmd=result.args, output=result.stdout) class DummyTask(Task): - def __init__(self, task_name, task_parameters, logger: logging.Logger): + def __init__(self, task_name, task_parameters): self.task_name = task_name self.task_parameters = task_parameters - self.logger = logger def execute(self): message = self.task_parameters.get('message', 'No message provided') - self.logger.info("%s: %s", message, self.task_name, extra={'tag': self.task_name}) + logger.info("%s: %s", message, self.task_name, extra={'tag': self.task_name}) class DummyRandomTask(Task): - def __init__(self, task_name, task_parameters, logger: logging.Logger): + def __init__(self, task_name, task_parameters): self.task_name = task_name self.task_parameters = task_parameters - self.logger = logger def execute(self): time_interval = self.task_parameters.get('time-seconds', [1, 5]) sleep_time = random.uniform(time_interval[0], time_interval[1]) message = self.task_parameters.get('message', 'No message provided') - self.logger.info("%s: %s (Sleeping for %.2f seconds)", message, self.task_name, 
sleep_time, extra={'tag': self.task_name}) + logger.info("%s: %s (Sleeping for %.2f seconds)", message, self.task_name, sleep_time, extra={'tag': self.task_name}) time.sleep(sleep_time) + +TASKS_HANDLERS = { + 'process': ProcessTask, + 'dummy': DummyTask, + 'dummy-random': DummyRandomTask, +} diff --git a/poc-tests/modules/workflow_engine/workflow_processor.py b/poc-tests/modules/workflow_engine/workflow_processor.py index a16cc80423..bb9f7a4b40 100644 --- a/poc-tests/modules/workflow_engine/workflow_processor.py +++ b/poc-tests/modules/workflow_engine/workflow_processor.py @@ -5,65 +5,22 @@ import graphlib import concurrent.futures import time +import json import logging from itertools import product import yaml -from .task import Task, ProcessTask, DummyTask, DummyRandomTask +from .task import * +logger = (lambda: logging.getLogger())() -class WorkflowProcessor: - """Class for processing a workflow.""" +class WorkflowFile: + """Class for loading and processing a workflow file.""" + def __init__(self, workflow_file_path: str): + self.workflow_raw_data = self.__load_workflow(workflow_file_path) + self.task_collection = self.__process_workflow() + self.__static_workflow_validation() - def __init__(self, workflow_file_path: str, dry_run: bool, threads: int): - """ - Initialize WorkflowProcessor. - - Args: - workflow_file_path (str): Path to the workflow file (YAML format). - dry_run (bool): Display the plan without executing tasks. - threads (int): Number of threads to use for parallel execution. 
- """ - self.workflow_data = self.load_workflow(workflow_file_path) - self.tasks = self.workflow_data.get('tasks', []) - self.variables = self.workflow_data.get('variables', {}) - self.task_collection = self.process_workflow() - self.static_workflow_validation() - self.failed_tasks = set() - self.logger = self.setup_logger() - self.dry_run = dry_run - self.threads = threads - - def setup_logger(self, log_format: str = 'plain', log_level: str = 'INFO') -> logging.Logger: - """ - Set up the logger. - - Args: - log_format (str): Log format (plain or json). - log_level (str): Log level. - - Returns: - logging.Logger: Logger instance. - """ - logger = logging.getLogger(__name__) - logger.setLevel(log_level) - - # Clear existing handlers to avoid duplicates - for handler in logger.handlers: - logger.removeHandler(handler) - - if log_format == 'json': - formatter = logging.Formatter('{"time": "%(asctime)s", "level": "%(levelname)s", "message": "%(message)s", "tag": "%(tag)s"}', datefmt="%Y-%m-%d %H:%M:%S") - else: - formatter = logging.Formatter("[%(asctime)s] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S") - - # Add a console handler - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) - - return logger - - def load_workflow(self, file_path: str) -> dict: + def __load_workflow(self, file_path: str) -> dict: """ Load the workflow data from a file. 
@@ -76,7 +33,15 @@ def load_workflow(self, file_path: str) -> dict: with open(file_path, 'r', encoding='utf-8') as file: return yaml.safe_load(file) - def replace_placeholders(self, element: str, values: dict, parent_key: str = None): + def __process_workflow(self): + """Process the workflow and return a list of tasks.""" + task_collection = [] + variables = self.workflow_raw_data.get('variables', {}) + for task in self.workflow_raw_data.get('tasks', []): + task_collection.extend(self.__expand_task(task, variables)) + return task_collection + + def __replace_placeholders(self, element: str, values: dict, parent_key: str = None): """ Recursively replace placeholders in a dictionary or list. @@ -89,15 +54,14 @@ def replace_placeholders(self, element: str, values: dict, parent_key: str = Non Any: The processed element. """ if isinstance(element, dict): - return {key: self.replace_placeholders(value, values, key) for key, value in element.items()} - elif isinstance(element, list): - return [self.replace_placeholders(sub_element, values, parent_key) for sub_element in element] - elif isinstance(element, str): + return {key: self.__replace_placeholders(value, values, key) for key, value in element.items()} + if isinstance(element, list): + return [self.__replace_placeholders(sub_element, values, parent_key) for sub_element in element] + if isinstance(element, str): return element.format_map(values) - else: - return element + return element - def expand_task(self, task: dict, variables: dict): + def __expand_task(self, task: dict, variables: dict): """ Expand a task with variable values. 
@@ -120,25 +84,18 @@ def expand_task(self, task: dict, variables: dict): for combination in product(*variable_values): variables_with_items = {**variables, **dict(zip(as_identifiers, combination))} - expanded_tasks.append(self.replace_placeholders(task, variables_with_items)) + expanded_tasks.append(self.__replace_placeholders(task, variables_with_items)) else: - expanded_tasks.append(self.replace_placeholders(task, variables)) + expanded_tasks.append(self.__replace_placeholders(task, variables)) return expanded_tasks - def process_workflow(self): - """Process the workflow and return a list of tasks.""" - task_collection = [] - for task in self.tasks: - task_collection.extend(self.expand_task(task, self.variables)) - return task_collection - - def static_workflow_validation(self): + def __static_workflow_validation(self): """Validate the workflow against static criteria.""" def check_duplicated_tasks(self): """Validate task name duplication.""" task_name_counts = {task['task']: 0 for task in self.task_collection} - + for task in self.task_collection: task_name_counts[task['task']] += 1 @@ -160,7 +117,44 @@ def check_not_existing_tasks(self): for validation in validations: validation(self) - def build_dependency_graph(self, reverse=False): + +class DAG(): + """Class for creating a dependency graph.""" + def __init__(self, task_collection: list, reverse: bool = False): + self.task_collection = task_collection + self.reverse = reverse + self.dag, self.dependency_tree = self.__build_dag() + self.to_be_canceled = set() + self.finished_tasks_status = { + 'failed': set(), + 'canceled': set(), + 'successful': set(), + } + self.execution_plan = self.__create_execution_plan(self.dependency_tree) + self.dag.prepare() + + def is_active(self) -> bool: + """Check if the DAG is active.""" + return self.dag.is_active() + + def get_available_tasks(self) -> list: + """Get the available tasks.""" + return self.dag.get_ready() + + def get_execution_plan(self) -> dict: + """Get the 
execution plan.""" + return self.execution_plan + + def set_status(self, task_name: str, status: str): + """Set the status of a task.""" + self.finished_tasks_status[status].add(task_name) + self.dag.done(task_name) + + def should_be_canceled(self, task_name: str) -> bool: + """Check if a task should be canceled.""" + return task_name in self.to_be_canceled + + def __build_dag(self): """Build a dependency graph for the tasks.""" dependency_dict = {} dag = graphlib.TopologicalSorter() @@ -169,7 +163,7 @@ def build_dependency_graph(self, reverse=False): task_name = task['task'] dependencies = task.get('depends-on', []) - if reverse: + if self.reverse: for dependency in dependencies: dag.add(dependency, task_name) else: @@ -179,138 +173,153 @@ def build_dependency_graph(self, reverse=False): return dag, dependency_dict - def execute_task(self, task: dict, action) -> None: - """Execute a task.""" - task_name = task['task'] + def cancel_dependant_tasks(self, task_name, cancel_policy) -> None: + """Cancel all tasks that depend on a failed task.""" + def get_all_task_set(tasks): + task_set = set() - self.logger.info("Starting task", extra={'tag': task_name}) - start_time = time.time() - - try: - task_object = self.create_task_object(task, action) - task_object.execute() - # Pass the tag to the tag_formatter function if it exists - tag_info = self.logger.tag_formatter(task_name) if hasattr(self.logger, 'tag_formatter') else {} - self.logger.info("Finished task in %.2f seconds", time.time() - start_time, extra={'tag': task_name, **tag_info}) - except Exception as e: - # Pass the tag to the tag_formatter function if it exists - tag_info = self.logger.tag_formatter(task_name) if hasattr(self.logger, 'tag_formatter') else {} - self.logger.error("Task failed with error: %s", e, extra={'tag': task_name, **tag_info}) - self.failed_tasks.add(task_name) - # Handle the exception or re-raise if necessary - raise + for task, sub_tasks in tasks.items(): + task_set.add(task) + 
task_set.update(get_all_task_set(sub_tasks)) - def create_task_object(self, task: dict, action) -> Task: - """Create and return a Task object based on task type.""" - task_type = task[action]['this'] + return task_set - task_classes = { - 'process': ProcessTask, - 'dummy': DummyTask, - 'dummy-random': DummyRandomTask, - } + if cancel_policy == 'continue': + return - task_class = task_classes.get(task_type) + not_cancelled_tasks = self.finished_tasks_status['failed'].union(self.finished_tasks_status['successful']) + for root_task, sub_tasks in self.execution_plan.items(): + task_set = get_all_task_set({root_task: sub_tasks}) + if cancel_policy == 'abort-all': + self.to_be_canceled.update(task_set) + elif cancel_policy == 'abort-related-flows': + if task_name in task_set: + self.to_be_canceled.update(task_set - not_cancelled_tasks) + else: + raise ValueError(f"Unknown cancel policy '{cancel_policy}'.") - if task_class is not None: - return task_class(task['task'], task[action]['with'], self.logger) + def __create_execution_plan(self, dependency_dict: dict) -> dict: - raise ValueError(f"Unknown task type '{task_type}'.") + execution_plan = {} - def get_root_tasks(self, dependency_dict: dict) -> set: - """Get root tasks from the dependency dictionary.""" - all_tasks = set(dependency_dict.keys()) - dependent_tasks = set(dep for dependents in dependency_dict.values() for dep in dependents) - return all_tasks - dependent_tasks + def get_root_tasks(dependency_dict: dict) -> set: + """Get root tasks from the dependency dictionary.""" + all_tasks = set(dependency_dict.keys()) + dependent_tasks = set(dep for dependents in dependency_dict.values() for dep in dependents) + return all_tasks - dependent_tasks - def print_execution_plan(self, task_name: str, dependency_dict: dict, level: int = 0) -> None: - """Print the execution plan recursively.""" - if task_name not in dependency_dict: - return + def get_subtask_plan(task_name: str, dependency_dict: dict, level: int = 0) -> 
dict: + """Create the execution plan recursively as a dictionary.""" + if task_name not in dependency_dict: + return {task_name: {}} - dependencies = dependency_dict[task_name] - indentation = " " * level - self.logger.info("%s%s", indentation, task_name) + dependencies = dependency_dict[task_name] + plan = {task_name: {}} - for dependency in dependencies: - self.print_execution_plan(dependency, dependency_dict, level + 1) + for dependency in dependencies: + sub_plan = get_subtask_plan(dependency, dependency_dict, level + 1) + plan[task_name].update(sub_plan) - def execute_tasks_parallel(self) -> None: - """Execute tasks in parallel.""" - dag, dependency_dict = self.build_dependency_graph() + return plan - if self.dry_run: - # Display the execution plan without executing tasks - root_tasks = self.get_root_tasks(dependency_dict) - for root_task in root_tasks: - self.print_execution_plan(root_task, dependency_dict) - else: - dag.prepare() + root_tasks = get_root_tasks(dependency_dict) + for root_task in root_tasks: + execution_plan.update(get_subtask_plan(root_task, dependency_dict)) - with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor: - futures = {} + return execution_plan - while True: - if not dag.is_active() or self.failed_tasks: - break - for task_name in dag.get_ready(): - dependencies = dependency_dict[task_name] +class WorkflowProcessor: + """Class for processing a workflow.""" - if any(dep in self.failed_tasks for dep in dependencies): - self.logger.info("[%s] Skipping task due to dependency failure", task_name) - self.failed_tasks.add(task_name) - dag.done(task_name) - continue + def __init__(self, workflow_file_path: str, dry_run: bool, threads: int): + """ + Initialize WorkflowProcessor. - dependent_futures = [futures[d] for d in dependencies if d in futures] + Args: + workflow_file_path (str): Path to the workflow file (YAML format). + dry_run (bool): Display the plan without executing tasks. 
+ threads (int): Number of threads to use for parallel execution. + """ + self.task_collection = WorkflowFile(workflow_file_path).task_collection + self.dry_run = dry_run + self.threads = threads - concurrent.futures.wait(dependent_futures) + def execute_task(self, dag: DAG, task: dict, action) -> None: + """Execute a task.""" + task_name = task['task'] + if dag.should_be_canceled(task_name): + logger.warning("[%s] Skipping task due to dependency failure.", task_name) + dag.set_status(task_name, 'canceled') + else: + try: + task_object = self.create_task_object(task, action) - task = next(t for t in self.task_collection if t['task'] == task_name) - future = executor.submit(self.execute_task, task, 'do') - futures[task_name] = future + logger.info("[%s] Starting task.", task_name) + start_time = time.time() + task_object.execute() + logger.info("[%s] Finished task in %.2f seconds.", task_name, time.time() - start_time) + dag.set_status(task_name, 'successful') + except Exception as e: + # Pass the tag to the tag_formatter function if it exists + logger.error("[%s] Task failed with error: %s.", task_name, e) + dag.set_status(task_name, 'failed') + dag.cancel_dependant_tasks(task_name, task.get('on-error', 'abort-related-flows')) + # Handle the exception or re-raise if necessary + raise - dag.done(task_name) + def create_task_object(self, task: dict, action) -> Task: + """Create and return a Task object based on task type.""" + task_type = task[action]['this'] - # Wait for all tasks to complete - concurrent.futures.wait(futures.values()) + task_handler = TASKS_HANDLERS.get(task_type) - # Now execute tasks based on the reverse DAG - reverse_dag, reverse_dependency_dict = self.build_dependency_graph(reverse=True) + if task_handler is not None: + return task_handler(task['task'], task[action]['with']) - reverse_dag.prepare() + raise ValueError(f"Unknown task type '{task_type}'.") + def execute_tasks_parallel(self) -> None: + """Execute tasks in parallel.""" + if not 
self.dry_run: + logger.info("Executing tasks in parallel.") + dag = DAG(self.task_collection) + # Execute tasks based on the DAG with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor: - reverse_futures = {} - + futures = {} while True: - if not reverse_dag.is_active() or self.failed_tasks: + if not dag.is_active(): break + for task_name in dag.get_available_tasks(): + task = next(t for t in self.task_collection if t['task'] == task_name) + future = executor.submit(self.execute_task, dag, task, 'do') + futures[task_name] = future + concurrent.futures.wait(futures.values()) - for task_name in reverse_dag.get_ready(): - dependencies = reverse_dependency_dict[task_name] - - if any(dep in self.failed_tasks for dep in dependencies): - self.logger.info("[%s] Skipping task due to dependency failure", task_name) - self.failed_tasks.add(task_name) - reverse_dag.done(task_name) - continue - - dependent_futures = [reverse_futures[d] for d in dependencies if d in reverse_futures] + # Now execute cleanup tasks based on the reverse DAG + reverse_dag = DAG(self.task_collection, reverse=True) - concurrent.futures.wait(dependent_futures) + logger.info("Executing cleanup tasks.") + with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor: + reverse_futures = {} + while True: + if not reverse_dag.is_active(): + break + for task_name in reverse_dag.get_available_tasks(): task = next(t for t in self.task_collection if t['task'] == task_name) if 'cleanup' in task: - future = executor.submit(self.execute_task, task, 'cleanup') + future = executor.submit(self.execute_task, reverse_dag, task, 'cleanup') reverse_futures[task_name] = future - reverse_dag.done(task_name) + else: + reverse_dag.set_status(task_name, 'successful') + concurrent.futures.wait(reverse_futures.values()) - # Wait for all tasks to complete - concurrent.futures.wait(reverse_futures.values()) + else: + dag = DAG(self.task_collection) + logger.info("Execution plan: %s", 
json.dumps(dag.get_execution_plan(), indent=2)) - def main(self) -> None: + def run(self) -> None: """Main entry point.""" self.execute_tasks_parallel() @@ -320,6 +329,6 @@ def abort_execution(self, executor: concurrent.futures.ThreadPoolExecutor, futur try: _ = future.result() except Exception as e: - self.logger.error("Error in aborted task: %s", e) + logger.error("Error in aborted task: %s", e) executor.shutdown(wait=False) diff --git a/poc-tests/scripts/workflow_engine.py b/poc-tests/scripts/workflow_engine.py index 416a9df984..ff967e95e8 100644 --- a/poc-tests/scripts/workflow_engine.py +++ b/poc-tests/scripts/workflow_engine.py @@ -5,12 +5,15 @@ import os import sys import argparse +import logging +import colorlog project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) sys.path.append(project_root) from modules.workflow_engine.workflow_processor import WorkflowProcessor from modules.classes import SchemaValidator + def parse_arguments() -> argparse.Namespace: """Parse command line arguments.""" parser = argparse.ArgumentParser(description='Execute tasks in a workflow.') @@ -18,11 +21,17 @@ def parse_arguments() -> argparse.Namespace: parser.add_argument('schema_file', type=str, default="./schema.json", help='Path to the schema definition file.') parser.add_argument('--threads', type=int, default=1, help='Number of threads to use for parallel execution.') parser.add_argument('--dry-run', action='store_true', help='Display the plan without executing tasks.') - parser.add_argument('--log-format', choices=['plain', 'json'], default='plain', help='Log format (plain or json).') parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='INFO', help='Log level.') return parser.parse_args() +def setup_logger(log_level: str) -> None: + """Setup logger.""" + logger = logging.getLogger() + console_handler = colorlog.StreamHandler() + 
console_handler.setFormatter(colorlog.ColoredFormatter("%(log_color)s[%(asctime)s] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")) + logger.addHandler(console_handler) + logger.setLevel(log_level) def main() -> None: """Main entry point.""" @@ -31,10 +40,9 @@ def main() -> None: validator = SchemaValidator(args.schema_file, args.workflow_file) validator.preprocess_data() validator.validateSchema() - + setup_logger(args.log_level) processor = WorkflowProcessor(args.workflow_file, args.dry_run, args.threads) - processor.logger = processor.setup_logger(log_format=args.log_format, log_level=args.log_level) - processor.main() + processor.run() if __name__ == "__main__":