diff --git a/src/bluesearch/database/article.py b/src/bluesearch/database/article.py index 6c95360a1..c9e15dcfb 100644 --- a/src/bluesearch/database/article.py +++ b/src/bluesearch/database/article.py @@ -17,6 +17,7 @@ """Abstraction of scientific article data and related tools.""" from __future__ import annotations +import dataclasses import enum import hashlib import html @@ -25,10 +26,10 @@ import string import unicodedata from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from io import StringIO from pathlib import Path -from typing import IO, Generator, Iterable, Optional, Sequence, Tuple +from typing import Dict, IO, Generator, Iterable, List, Optional, Sequence, Tuple from xml.etree.ElementTree import Element # nosec from zipfile import ZipFile @@ -1071,6 +1072,7 @@ class Article(DataClassJSONMixin): arxiv_id: Optional[str] = None doi: Optional[str] = None uid: Optional[str] = None + topics: Dict[str, List[str]] = field(default_factory=dict) @classmethod def parse(cls, parser: ArticleParser) -> Article: diff --git a/src/bluesearch/entrypoint/database/parse.py b/src/bluesearch/entrypoint/database/parse.py index e2b4165ec..8b9337861 100644 --- a/src/bluesearch/entrypoint/database/parse.py +++ b/src/bluesearch/entrypoint/database/parse.py @@ -112,6 +112,14 @@ def init_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: Parse files recursively. """, ) + parser.add_argument( + "-i", + "--include-topic", + action="store_true", + help=""" + If True, include topic inside the parsed json. + """, + ) parser.add_argument( "-n", "--dry-run", @@ -164,6 +172,7 @@ def run( output_dir: Path, match_filename: str | None, recursive: bool, + include_topic: bool, dry_run: bool, ) -> int: """Parse one or several articles. @@ -171,7 +180,9 @@ def run( Parameter description and potential defaults are documented inside of the `get_parser` function. 
""" - from bluesearch.utils import find_files + import json + + from bluesearch.utils import JSONL, find_files if input_path is None: if sys.stdin.isatty(): @@ -211,13 +222,27 @@ def run( try: parsers = iter_parsers(input_type, input_path) - for parser in parsers: + for i, parser in enumerate(parsers): article = Article.parse(parser) output_file = output_dir / f"{article.uid}.json" if output_file.exists(): raise FileExistsError(f"Output '{output_file}' already exists!") else: + + if include_topic: + topic_path = ( + input_path.parent.parent + / "topic" + / f"{input_path.stem}.json" + ) + topic_json = JSONL.load_jsonl(topic_path) + + if input_type == "pubmed-xml-set": + article.topics = topic_json[i]["topics"] + else: + article.topics = topic_json[0]["topics"] + serialized = article.to_json() output_file.write_text(serialized, "utf-8") diff --git a/src/bluesearch/entrypoint/database/run.py b/src/bluesearch/entrypoint/database/run.py index d2414331c..c6f9c2784 100644 --- a/src/bluesearch/entrypoint/database/run.py +++ b/src/bluesearch/entrypoint/database/run.py @@ -105,6 +105,7 @@ class GlobalParams(luigi.Config): """Global configuration.""" source = luigi.Parameter() + include_topic = luigi.Parameter(default=False) class DownloadTask(ExternalProgramTask): @@ -230,6 +231,9 @@ def program_args(self) -> list[str]: output_dir, ] + if GlobalParams().include_topic: + command.extend("--i") + if GlobalParams().source in {"medrxiv", "biorxiv"}: command.extend( ["-R", "-m", r".*\.meca$"], @@ -438,6 +442,9 @@ def program_args(self) -> list[str]: str(output_dir), ] + if GlobalParams().include_topic: + command.extend("--i") + return command diff --git a/src/bluesearch/entrypoint/database/topic_extract.py b/src/bluesearch/entrypoint/database/topic_extract.py index b045d364d..eef792634 100644 --- a/src/bluesearch/entrypoint/database/topic_extract.py +++ b/src/bluesearch/entrypoint/database/topic_extract.py @@ -19,6 +19,7 @@ import argparse import gzip +import json import logging 
from pathlib import Path from typing import Any @@ -78,6 +79,14 @@ def init_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: expression. Ignored when 'input_path' is a path to a file. """, ) + parser.add_argument( + "-i", + "--inc-individual-json", + action="store_true", + help=""" + If given, also save an individual topic JSON file for each input file. + """, + ) parser.add_argument( "-R", "--recursive", @@ -121,12 +130,33 @@ def init_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: return parser +def create_individual_json(path, topic_info): + """Create json containing the extracted topics. + + Parameters + ---------- + path + Path of the original article + topic_info + Topics extracted for the given article. + """ + folder = path.parent.parent / "topic" + if not folder.exists(): + folder.mkdir(parents=True, exist_ok=True) + + new_path = folder / f"{path.stem}.json" + with new_path.open("w") as f: + line = json.dumps(topic_info.json()) + f.write(line) + + def run( *, source: str, input_path: Path, output_file: Path, match_filename: str | None, + inc_individual_json: bool, recursive: bool, overwrite: bool, dry_run: bool, @@ -182,6 +212,10 @@ def run( topic_info.add_journal_topics( "MeSH", mesh.resolve_parents(journal_topics, mesh_tree) ) + + if inc_individual_json: + create_individual_json(path, topic_info) + all_results.append(topic_info.json()) elif article_source is ArticleSource.PUBMED: if mesh_topic_db is None: @@ -194,6 +228,7 @@ def run( logger.info(f"Processing {path}") with gzip.open(path) as xml_stream: articles = ElementTree.parse(xml_stream) + topics_per_file = [] for i, article in enumerate(articles.iter("PubmedArticle")): logger.info(f"Processing element in file {i}") @@ -217,11 +252,23 @@ def run( "MeSH", mesh.resolve_parents(journal_topics, mesh_tree) ) all_results.append(topic_info.json()) + topics_per_file.append(topic_info.json()) + if inc_individual_json: + folder = path.parent.parent / "topic" + if not folder.exists(): + folder.mkdir(parents=True, exist_ok=True) + + new_path = folder / f"{path.stem}.json" + 
JSONL.dump_jsonl(topics_per_file, new_path) + elif article_source is ArticleSource.ARXIV: for path, article_topics in get_topics_for_arxiv_articles(inputs).items(): topic_info = TopicInfo(source=article_source, path=path) topic_info.add_article_topics("arXiv", article_topics) + if inc_individual_json: + create_individual_json(path, topic_info) + all_results.append(topic_info.json()) elif article_source in {ArticleSource.BIORXIV, ArticleSource.MEDRXIV}: for path in inputs: @@ -234,6 +281,9 @@ def run( topic_info = TopicInfo(source=ArticleSource(journal), path=path) topic_info.add_article_topics("Subject Area", [topic]) + if inc_individual_json: + create_individual_json(path, topic_info) + all_results.append(topic_info.json()) else: logger.error(f"The source type {source!r} is not implemented yet") diff --git a/tests/unit/entrypoint/database/test_topic_extract.py b/tests/unit/entrypoint/database/test_topic_extract.py index 2de8d94a3..9d01ed0da 100644 --- a/tests/unit/entrypoint/database/test_topic_extract.py +++ b/tests/unit/entrypoint/database/test_topic_extract.py @@ -35,6 +35,7 @@ "overwrite", "dry_run", "mesh_topic_db", + "inc_individual_json", } @@ -68,6 +69,7 @@ def test_input_path_not_correct(caplog): output_file=pathlib.Path(""), match_filename=None, recursive=False, + inc_individual_json=False, overwrite=False, dry_run=False, ) @@ -84,6 +86,7 @@ def test_source_type_not_implemented(test_data_path, caplog, tmp_path): output_file=tmp_path, match_filename=None, recursive=False, + inc_individual_json=False, overwrite=False, dry_run=False, ) @@ -99,6 +102,7 @@ def test_dry_run(test_data_path, capsys, tmp_path): output_file=tmp_path, match_filename=None, recursive=False, + inc_individual_json=False, overwrite=False, dry_run=True, ) @@ -131,6 +135,7 @@ def test_pmc_source(test_data_path, capsys, monkeypatch, tmp_path): output_file=output_jsonl, match_filename=None, recursive=False, + inc_individual_json=False, overwrite=False, dry_run=False, 
mesh_topic_db=mesh_tree_path, @@ -157,6 +162,7 @@ def test_pmc_source(test_data_path, capsys, monkeypatch, tmp_path): output_file=output_jsonl, match_filename=None, recursive=False, + inc_individual_json=False, overwrite=True, dry_run=False, mesh_topic_db=mesh_tree_path, @@ -172,6 +178,7 @@ def test_pmc_source(test_data_path, capsys, monkeypatch, tmp_path): output_file=output_jsonl, match_filename=None, recursive=False, + inc_individual_json=False, overwrite=False, dry_run=False, mesh_topic_db=mesh_tree_path, @@ -203,6 +210,7 @@ def test_medbiorxiv_source(capsys, monkeypatch, tmp_path, source): output_file=output_file, match_filename=None, recursive=False, + inc_individual_json=False, overwrite=False, dry_run=False, ) @@ -251,6 +259,7 @@ def test_pubmed_source( output_file=output_jsonl, match_filename=None, recursive=False, + inc_individual_json=False, overwrite=False, dry_run=False, mesh_topic_db=mesh_tree_path, @@ -285,6 +294,7 @@ def test_mesh_topic_db_is_enforced(source, caplog, tmp_path): output_file=tmp_path, match_filename=None, recursive=False, + inc_individual_json=False, overwrite=False, dry_run=False, ) diff --git a/tox.ini b/tox.ini index 33d65bdc6..75d576837 100644 --- a/tox.ini +++ b/tox.ini @@ -155,6 +155,7 @@ filterwarnings = ignore::DeprecationWarning:docker.*: ignore::DeprecationWarning:luigi.task: ignore::DeprecationWarning:transformers.image_utils.*: + ignore::luigi.parameter.UnconsumedParameterWarning addopts = --cov --cov-config=tox.ini