From a7b2401deaf12f4b5fa5587d0d0024028eb7c5a8 Mon Sep 17 00:00:00 2001 From: Emilie Delattre Date: Tue, 15 Mar 2022 13:10:42 +0100 Subject: [PATCH] Remove requires/inherits decorator --- luigi.cfg | 23 ++ src/bluesearch/entrypoint/database/run.py | 249 ++++++---------------- 2 files changed, 91 insertions(+), 181 deletions(-) diff --git a/luigi.cfg b/luigi.cfg index 8830e6998..5f9baf652 100644 --- a/luigi.cfg +++ b/luigi.cfg @@ -2,3 +2,26 @@ autoload_range=true log_level = INFO local_scheduler = True + +[GlobalParams] + source=pubmed + +[DownloadTask] + from_month=2021-12 + output_dir=luigi-pipeline + identifier= + ; emtpy string is considered default value + +[TopicExtractTask] + mesh_topic_db=luigi-pipeline/mesh_topic_db.json + +[TopicFilterTask] + filter_config=luigi-pipeline/filter-config.jsonl + +[ConvertPDFTask] + grobid_host=0.0.0.0 + grobid_port=8070 + +[AddTask] + db_url=luigi-pipeline/my-db.db + db_type=sqlite \ No newline at end of file diff --git a/src/bluesearch/entrypoint/database/run.py b/src/bluesearch/entrypoint/database/run.py index 824ea85ad..bc86f4d7f 100644 --- a/src/bluesearch/entrypoint/database/run.py +++ b/src/bluesearch/entrypoint/database/run.py @@ -32,9 +32,6 @@ from defusedxml.ElementTree import tostring from luigi.contrib.external_program import ExternalProgramTask from luigi.tools.deps_tree import print_tree -from luigi.util import inherits, requires - -from bluesearch.database.article import ArticleSource logger = logging.getLogger(__name__) @@ -56,70 +53,19 @@ def init_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.description = "Run the overall pipeline." parser.add_argument( - "--source", - required=True, - type=str, - choices=[member.value for member in ArticleSource], - help="Source of the articles.", - ) - parser.add_argument( - "--from-month", - required=True, - type=str, - help="The starting month (included) for the download in format YYYY-MM. " - "All papers from the given month until today will be downloaded.", - ) - parser.add_argument( - "--filter-config", - required=True, - type=Path, - help=""" - Path to a .JSONL file that defines all the rules for filtering. - """, - ) - parser.add_argument( - "--output-dir", - required=True, - type=Path, - help=""" - Path to the output folder. All the results stored under - `output_dir/source/date` where date is concatenation of the - `from_month` and the day of execution of this command. - """, - ) - parser.add_argument( - "--db-url", - required=True, - type=str, - help=""" - The location of the database depending on the database type. - - For MySQL and MariaDB the server URL should be provided, for SQLite the - location of the database file. Generally, the scheme part of - the URL should be omitted, e.g. for MySQL the URL should be - of the form 'my_sql_server.ch:1234/my_database' and for SQLite - of the form '/path/to/the/local/database.db'. - """, - ) - parser.add_argument( - "--db-type", - default="sqlite", + "--final-task", type=str, - choices=("mariadb", "mysql", "postgres", "sqlite"), - help="Type of the database.", - ) - parser.add_argument( - "--mesh-topic-db", - type=Path, - help=""" - The JSON file with MeSH topic hierarchy information. Mandatory for - source types "pmc" and "pubmed". - - The JSON file should contain a flat dictionary with MeSH topic tree - numbers mapped to the corresponding topic labels. This file can be - produced using the `bbs_database parse-mesh-rdf` command. See that - command's description for more details. - """, + choices=( + "DownloadTask", + "UnzipTask", + "TopicExtractTask", + "TopicFilterTask", + "PerformFilteringTask", + "ConvertPDFTask", + "ParseTask", + "AddTask", + ), + help="Final task of the luigi pipeline.", ) parser.add_argument( "--dry-run", @@ -127,30 +73,6 @@ def init_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: action="store_true", help="Prints out a diagram of the pipeline without running it.", ) - parser.add_argument( - "--grobid-host", - type=str, - help="The host of the GROBID server.", - ) - parser.add_argument( - "--grobid-port", - type=int, - help="The port of the GROBID server.", - ) - parser.add_argument( - "--identifier", - type=str, - help="""Custom name of the identifier. If not specified, we use - `from-month_today` - """, - ) - parser.add_argument( - "--final-task", - type=str, - help="""Name of the task where to manually stop the pipeline. Note - that the task itself will be included. - """, - ) return parser @@ -161,13 +83,18 @@ def init_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: IDENTIFIER = None # make sure the same for all tasks +class GlobalParams(luigi.Config): + """Global configuration.""" + + source = luigi.Parameter() + + class DownloadTask(ExternalProgramTask): """Download raw files. They will be stored in the `raw/` folder. """ - source = luigi.Parameter() from_month = luigi.Parameter() output_dir = luigi.Parameter() identifier = luigi.OptionalParameter() @@ -186,7 +113,7 @@ def output(self) -> luigi.LocalTarget: else: identifier = IDENTIFIER - output_dir = Path(self.output_dir) / self.source / identifier / "raw" + output_dir = Path(self.output_dir) / GlobalParams().source / identifier / "raw" return luigi.LocalTarget(str(output_dir)) @@ -197,13 +124,12 @@ def program_args(self) -> list[str]: *BBS_BINARY, "download", *VERBOSITY, - self.source, + GlobalParams().source, self.from_month, output_dir, ] -@requires(DownloadTask) class UnzipTask(ExternalProgramTask): """Unzip raw files (if necessary). @@ -211,7 +137,10 @@ class UnzipTask(ExternalProgramTask): are stored inside of `raw_unzipped`. """ - source = luigi.Parameter() + @staticmethod + def requires() -> luigi.Task: + """Define dependency.""" + return DownloadTask() def output(self) -> luigi.LocalTarget: """Define unzipping folder.""" @@ -226,7 +155,7 @@ def run(self) -> None: output_dir = Path(self.output().path) # raw_unzipped output_dir.mkdir(exist_ok=True, parents=True) - if self.source == "pmc": + if GlobalParams().source == "pmc": # .tar.gz # We want collapse the folder hierarchy all_tar_files = input_dir.rglob("*.tar.gz") @@ -241,10 +170,9 @@ def run(self) -> None: shutil.copyfileobj(f_in, f_out) # type: ignore else: - raise ValueError(f"Unsupported source {self.source}") + raise ValueError(f"Unsupported source {GlobalParams().source}") -@inherits(DownloadTask, UnzipTask) class TopicExtractTask(ExternalProgramTask): """Topic extraction. @@ -253,15 +181,15 @@ class TopicExtractTask(ExternalProgramTask): `topic_infos.jsonl`. """ - source = luigi.Parameter() mesh_topic_db = luigi.Parameter() - def requires(self) -> luigi.Task: + @staticmethod + def requires() -> luigi.Task: """Define conditional dependencies.""" - if self.source in {"pmc"}: - return self.clone(UnzipTask) + if GlobalParams().source in {"pmc"}: + return UnzipTask() else: - return self.clone(DownloadTask) + return DownloadTask() def output(self) -> luigi.LocalTarget: """Define output file path.""" @@ -279,20 +207,20 @@ def program_args(self) -> list[str]: *BBS_BINARY, "topic-extract", *VERBOSITY, - self.source, + GlobalParams().source, input_dir, output_dir, ] - if self.source in {"medrxiv", "biorxiv"}: + if GlobalParams().source in {"medrxiv", "biorxiv"}: command.extend( ["-R", "-m", r".*\.meca$"], ) - if self.source in {"pmc", "pubmed"}: + if GlobalParams().source in {"pmc", "pubmed"}: command.append(f"--mesh-topic-db={self.mesh_topic_db}") - if self.source == "pubmed": + if GlobalParams().source == "pubmed": command.extend( ["-R", "-m", r".*\.xml\.gz$"], ) @@ -300,7 +228,6 @@ def program_args(self) -> list[str]: return command -@requires(TopicExtractTask) class TopicFilterTask(ExternalProgramTask): """Run topic filtering entrypoint. @@ -310,6 +237,11 @@ class TopicFilterTask(ExternalProgramTask): filter_config = luigi.Parameter() + @staticmethod + def requires() -> luigi.Task: + """Define dependency.""" + return TopicExtractTask() + def output(self) -> luigi.LocalTarget: """Define output file.""" output_file = Path(self.input().path).parent / "filtering.csv" @@ -333,7 +265,6 @@ def program_args(self) -> list[str]: return command -@requires(TopicFilterTask) class PerformFilteringTask(luigi.Task): """Create folder that only contains relevant articles. @@ -341,6 +272,11 @@ class PerformFilteringTask(luigi.Task): stage. The only input is the `filtering.csv`. """ + @staticmethod + def requires() -> luigi.Task: + """Define dependency.""" + return TopicFilterTask() + def output(self) -> luigi.LocalTarget: """Define output folder.""" output_dir = Path(self.input().path).parent / "filtered" @@ -355,7 +291,7 @@ def run(self) -> None: output_dir.mkdir(exist_ok=True) - if self.source == "pubmed": + if GlobalParams().source == "pubmed": # Find all input files (.xml.gz) all_input_files = [Path(p) for p in filtering["path"].unique()] @@ -396,7 +332,6 @@ def create_symlink(path): accepted.apply(create_symlink) -@requires(PerformFilteringTask) class ConvertPDFTask(ExternalProgramTask): """Convert PDFs to XMLs. @@ -407,6 +342,11 @@ class ConvertPDFTask(ExternalProgramTask): grobid_host = luigi.Parameter() grobid_port = luigi.IntParameter() + @staticmethod + def requires() -> luigi.Task: + """Define dependency.""" + return PerformFilteringTask() + def program_args(self) -> list[str]: """Define subprocess arguments.""" input_dir = Path(self.input().path).parent / "filtered" @@ -431,7 +371,6 @@ def output(self) -> luigi.LocalTarget: return luigi.LocalTarget(str(output_file)) -@inherits(ConvertPDFTask, PerformFilteringTask) class ParseTask(ExternalProgramTask): """Parse articles. @@ -439,12 +378,13 @@ class ParseTask(ExternalProgramTask): `source="arxiv"` `converted_pdfs/`). """ - def requires(self) -> luigi.Task: + @staticmethod + def requires() -> luigi.Task: """Define conditional dependencies.""" - if self.source == "arxiv": - return self.clone(ConvertPDFTask) + if GlobalParams().source == "arxiv": + return ConvertPDFTask() else: - return self.clone(PerformFilteringTask) + return PerformFilteringTask() def output(self) -> luigi.LocalTarget: """Define output folder.""" @@ -469,7 +409,7 @@ def program_args(self) -> list[str]: "pmc": "jats-xml", "pubmed": "pubmed-xml-set", } - parser = source2parser[self.source] + parser = source2parser[GlobalParams().source] command = [ *BBS_BINARY, @@ -483,7 +423,6 @@ def program_args(self) -> list[str]: return command -@requires(ParseTask) class AddTask(ExternalProgramTask): """Add parsed articles to the database. @@ -494,6 +433,11 @@ class AddTask(ExternalProgramTask): db_url = luigi.Parameter() db_type = luigi.Parameter() + @staticmethod + def requires() -> luigi.Task: + """Define dependency.""" + return ParseTask() + def complete(self) -> bool: """Check if all articles inside of `parsed/` are in the database.""" # If all the articles are inside @@ -542,45 +486,10 @@ def program_args(self) -> list[str]: return command -def get_all_dependencies(task: luigi.Task) -> set[luigi.Task]: - """Get all dependencies of a given task. - - Parameters - ---------- - task - Input task - - Returns - ------- - set[luigi.Task] - All the tasks that the `input` depends on including itself. - """ - current_deps = set(task.deps()) - if not current_deps: - return set() - - else: - deps = {task} - for current_dep in current_deps: - deps |= get_all_dependencies(current_dep) - - return deps | current_deps - - def run( *, - source: str, - from_month: str, - filter_config: Path, - output_dir: Path, - db_url: str, - db_type: str, - mesh_topic_db: Path | None, - dry_run: bool, - grobid_host: str | None, - grobid_port: int | None, - identifier: str | None, final_task: str | None, + dry_run: bool, ) -> int: """Run overall pipeline. @@ -598,36 +507,14 @@ def run( ParseTask.capture_output = CAPTURE_OUTPUT AddTask.capture_output = CAPTURE_OUTPUT - add_task_inst = AddTask( - source=source, - from_month=from_month, - filter_config=str(filter_config), - output_dir=str(output_dir), - mesh_topic_db=str(mesh_topic_db), - grobid_host=grobid_host, - grobid_port=grobid_port, - db_url=db_url, - db_type=db_type, - identifier=identifier, - ) - if final_task is None: - selected_task_inst = add_task_inst + if final_task: + final_task = globals()[final_task] else: - all_dependencies = get_all_dependencies(add_task_inst) - all_dependencies_map = {t.__class__.__name__: t for t in all_dependencies} + final_task = AddTask - if final_task in all_dependencies_map: - selected_task_inst = all_dependencies_map[final_task] - else: - raise ValueError(f"Unrecognized final task {final_task}") - - luigi_kwargs = { - "tasks": [selected_task_inst], - } if dry_run: - print(print_tree(selected_task_inst, last=False)) + print(print_tree(final_task(), last=False)) else: - - luigi.build(**luigi_kwargs) + luigi.build([final_task()]) return 0