Add PubMedQA dataset #740

Merged · 26 commits · Jan 21, 2025
205 changes: 190 additions & 15 deletions pdm.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -83,6 +83,7 @@ all = [
"imagesize>=1.4.1",
"scipy>=1.14.0",
"monai>=1.3.2",
"datasets>=3.2.0",
]

[project.scripts]
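The `datasets>=3.2.0` pin added above backs the HuggingFace loading introduced below. A quick sanity check of the resolved version (a sketch; nothing in this PR requires it):

```python
# Sketch: confirm the installed `datasets` distribution satisfies the
# `datasets>=3.2.0` pin from pyproject.toml.
from importlib.metadata import version

print(version("datasets"))  # expect 3.2.0 or newer
```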
13 changes: 13 additions & 0 deletions src/eva/language/__init__.py
@@ -0,0 +1,13 @@
"""eva language API."""

try:
from eva.language.data import datasets
except ImportError as e:
msg = (
"eva language requirements are not installed.\n\n"
"Please pip install as follows:\n"
' python -m pip install "kaiko-eva[language]" --upgrade'
)
raise ImportError(str(e) + "\n\n" + msg) from e

__all__ = ["datasets"]
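A small sketch of what this guard looks like from the caller's side when the extra is missing (the install hint is the one embedded in the message above):

```python
# Sketch: without `kaiko-eva[language]` installed, importing the package
# raises an ImportError that carries the pip install hint.
try:
    from eva.language.data import datasets
except ImportError as err:
    print(err)  # original error plus: python -m pip install "kaiko-eva[language]" --upgrade
```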
5 changes: 5 additions & 0 deletions src/eva/language/data/__init__.py
@@ -0,0 +1,5 @@
"""Language data API."""

from eva.language.data import datasets

__all__ = ["datasets"]
9 changes: 9 additions & 0 deletions src/eva/language/data/datasets/__init__.py
@@ -0,0 +1,9 @@
"""Language Datasets API."""

from eva.language.data.datasets.classification import PubMedQA
from eva.language.data.datasets.language import LanguageDataset

__all__ = [
"PubMedQA",
"LanguageDataset",
]
7 changes: 7 additions & 0 deletions src/eva/language/data/datasets/classification/__init__.py
@@ -0,0 +1,7 @@
"""Text classification datasets API."""

from eva.language.data.datasets.classification.pubmedqa import PubMedQA

__all__ = [
"PubMedQA",
]
63 changes: 63 additions & 0 deletions src/eva/language/data/datasets/classification/base.py
@@ -0,0 +1,63 @@
"""Base for text classification datasets."""

import abc
from typing import Any, Dict, List, Tuple

import torch
from typing_extensions import override

from eva.language.data.datasets.language import LanguageDataset


class TextClassification(LanguageDataset[Tuple[str, torch.Tensor]], abc.ABC):
"""Text classification abstract dataset."""

def __init__(self) -> None:
"""Initializes the text classification dataset."""
super().__init__()

    @property
    def classes(self) -> List[str] | None:
        """Returns the list of class names, or `None` if not defined."""

    @property
    def class_to_idx(self) -> Dict[str, int] | None:
        """Returns the class name to index mapping, or `None` if not defined."""

    def load_metadata(self, index: int) -> Dict[str, Any] | None:
        """Returns the dataset metadata.

        Args:
            index: The index of the data sample.

        Returns:
            The sample metadata, or `None` if the dataset provides none.
        """

@abc.abstractmethod
def load_text(self, index: int) -> str:
"""Returns the text content.

Args:
index: The index of the data sample.

Returns:
The text content.
"""
raise NotImplementedError

@abc.abstractmethod
def load_target(self, index: int) -> torch.Tensor:
"""Returns the target label.

Args:
index: The index of the data sample.

Returns:
The target label.
"""
raise NotImplementedError

@override
def __getitem__(self, index: int) -> Tuple[str, torch.Tensor, Dict[str, Any]]:
return (self.load_text(index), self.load_target(index), self.load_metadata(index) or {})
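To illustrate the contract, a minimal concrete subclass; `ToyTextDataset`, its samples, and its classes are hypothetical and not part of this PR:

```python
"""A minimal, hypothetical TextClassification subclass (illustration only)."""

from typing import Dict, List

import torch
from typing_extensions import override

from eva.language.data.datasets.classification import base


class ToyTextDataset(base.TextClassification):
    """Tiny in-memory dataset exercising the abstract interface."""

    _samples = [("the trial met its endpoint", 1), ("results were inconclusive", 0)]

    @property
    @override
    def classes(self) -> List[str]:
        return ["no", "yes"]

    @property
    @override
    def class_to_idx(self) -> Dict[str, int]:
        return {"no": 0, "yes": 1}

    @override
    def load_text(self, index: int) -> str:
        return self._samples[index][0]

    @override
    def load_target(self, index: int) -> torch.Tensor:
        return torch.tensor(self._samples[index][1], dtype=torch.long)

    def __len__(self) -> int:
        return len(self._samples)
```

Indexing such a dataset yields the `(text, target, metadata)` triple produced by `__getitem__` above, with `{}` standing in for the unimplemented metadata.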
138 changes: 138 additions & 0 deletions src/eva/language/data/datasets/classification/pubmedqa.py
@@ -0,0 +1,138 @@
"""PubMedQA dataset class."""

import os
from typing import Any, Dict, List, Literal

import torch
from datasets import Dataset, load_dataset
from loguru import logger
from typing_extensions import override

from eva.language.data.datasets.classification import base


class PubMedQA(base.TextClassification):
"""Dataset class for PubMedQA question answering task."""

_license: str = "MIT License (https://github.com/pubmedqa/pubmedqa/blob/master/LICENSE)"
"""Dataset license."""

def __init__(
self,
root: str | None = None,
split: Literal["train", "val", "test"] | None = None,
download: bool = False,
) -> None:
"""Initialize the PubMedQA dataset.

Args:
root: Directory to cache the dataset. If None, no local caching is used.
split: Valid splits among ["train", "val", "test"].
If None, it will use "train+test+validation".
download: Whether to download the dataset if not found locally. Default is False.
"""
super().__init__()

self._root = root
self._split = split
self._download = download

def _load_dataset(self, dataset_cache_path: str | None) -> Dataset:
"""Loads the PubMedQA dataset from the local cache or downloads it if needed.

Args:
dataset_cache_path: The path to the local cache (may be None).

Returns:
The loaded Dataset object.
"""
if dataset_cache_path is not None and os.path.exists(dataset_cache_path):
dataset_path = dataset_cache_path
logger.info(f"Loaded dataset from local cache: {dataset_cache_path}")
is_local = True
else:
if not self._download and self._root:
raise ValueError(
"Dataset not found locally and downloading is disabled. "
"Set `download=True` or provide a valid local cache."
)
dataset_path = "bigbio/pubmed_qa"
is_local = False

if self._root:
logger.info(f"Dataset will be downloaded and cached in: {self._root}")
else:
logger.info("Using dataset directly from HuggingFace without caching.")

split = (self._split or "train+test+validation") if self._split != "val" else "validation"
raw_dataset = load_dataset(
dataset_path,
name="pubmed_qa_labeled_fold0_source",
split=split,
streaming=False,
cache_dir=self._root if (not is_local and self._root) else None,
)
if not isinstance(raw_dataset, Dataset):
raise TypeError(f"Expected a `Dataset`, but got {type(raw_dataset)}")

return raw_dataset

@override
def prepare_data(self) -> None:
"""Downloads and prepares the PubMedQA dataset.

If `self._root` is None, the dataset is used directly from HuggingFace.
Otherwise, it checks if the dataset is already cached in `self._root`.
If not cached, it downloads the dataset into `self._root`.
"""
dataset_cache_path = None

if self._root:
dataset_cache_path = os.path.join(self._root, "pubmed_qa")
os.makedirs(self._root, exist_ok=True)

try:
self.dataset = self._load_dataset(dataset_cache_path)
except Exception as e:
raise RuntimeError(f"Failed to prepare dataset: {e}") from e

@property
@override
def classes(self) -> List[str]:
return ["no", "yes", "maybe"]

@property
@override
def class_to_idx(self) -> Dict[str, int]:
return {"no": 0, "yes": 1, "maybe": 2}

@override
def load_text(self, index: int) -> str:
sample = dict(self.dataset[index])
return f"Question: {sample['QUESTION']}\nContext: {sample['CONTEXTS']}"

@override
def load_target(self, index: int) -> torch.Tensor:
return torch.tensor(
self.class_to_idx[self.dataset[index]["final_decision"]], dtype=torch.long
)

@override
def load_metadata(self, index: int) -> Dict[str, Any]:
sample = self.dataset[index]
return {
"year": sample["YEAR"],
"labels": sample["LABELS"],
"meshes": sample["MESHES"],
"long_answer": sample["LONG_ANSWER"],
"reasoning_required": sample["reasoning_required_pred"],
"reasoning_free": sample["reasoning_free_pred"],
}

@override
def __len__(self) -> int:
return len(self.dataset)

def _print_license(self) -> None:
"""Prints the dataset license."""
print(f"Dataset license: {self._license}")
13 changes: 13 additions & 0 deletions src/eva/language/data/datasets/language.py
@@ -0,0 +1,13 @@
"""Language Dataset base class."""

import abc
from typing import Generic, TypeVar

from eva.core.data.datasets import base

DataSample = TypeVar("DataSample")
"""The data sample type."""


class LanguageDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
"""Base dataset class for text tasks."""
1 change: 1 addition & 0 deletions tests/eva/language/__init__.py
@@ -0,0 +1 @@
"""EVA language tests."""
1 change: 1 addition & 0 deletions tests/eva/language/data/__init__.py
@@ -0,0 +1 @@
"""Language data tests."""
1 change: 1 addition & 0 deletions tests/eva/language/data/datasets/__init__.py
@@ -0,0 +1 @@
"""Language datasets tests."""
1 change: 1 addition & 0 deletions tests/eva/language/data/datasets/classification/__init__.py
@@ -0,0 +1 @@
"""Tests for the text classification datasets."""
121 changes: 121 additions & 0 deletions tests/eva/language/data/datasets/classification/test_pubmedqa.py
@@ -0,0 +1,121 @@
"""PubMedQA dataset tests."""

import os
import shutil

import pytest
import torch
from datasets import Dataset

from eva.language.data import datasets


@pytest.mark.parametrize(
"split, expected_length",
[("train", 450), ("test", 500), ("val", 50), (None, 1000)],
)
def test_length(pubmedqa_dataset: datasets.PubMedQA, expected_length: int) -> None:
"""Tests the length of the dataset."""
assert len(pubmedqa_dataset) == expected_length


@pytest.mark.parametrize(
"split, index",
[
("train", 0),
("train", 10),
("test", 0),
("val", 0),
(None, 0),
],
)
def test_sample(pubmedqa_dataset: datasets.PubMedQA, index: int) -> None:
"""Tests the format of a dataset sample."""
sample = pubmedqa_dataset[index]
assert isinstance(sample, tuple)
assert len(sample) == 3

text, target, metadata = sample
assert isinstance(text, str)
assert text.startswith("Question: ")
assert "Context: " in text

assert isinstance(target, torch.Tensor)
assert target in [0, 1, 2]

assert isinstance(metadata, dict)
required_keys = {
"year",
"labels",
"meshes",
"long_answer",
"reasoning_required",
"reasoning_free",
}
assert all(key in metadata for key in required_keys)


@pytest.mark.parametrize("split", [None])
def test_classes(pubmedqa_dataset: datasets.PubMedQA) -> None:
"""Tests the dataset classes."""
assert pubmedqa_dataset.classes == ["no", "yes", "maybe"]
assert pubmedqa_dataset.class_to_idx == {"no": 0, "yes": 1, "maybe": 2}


@pytest.mark.parametrize("split", [None])
def test_prepare_data_no_root(pubmedqa_dataset: datasets.PubMedQA) -> None:
"""Tests dataset preparation without specifying a root directory."""
assert isinstance(pubmedqa_dataset.dataset, Dataset)
assert len(pubmedqa_dataset) > 0


@pytest.mark.parametrize("split", [None])
def test_prepare_data_with_cache(pubmedqa_dataset_with_cache: datasets.PubMedQA) -> None:
"""Tests dataset preparation with caching."""
pubmedqa_dataset_with_cache.prepare_data()
assert isinstance(pubmedqa_dataset_with_cache.dataset, Dataset)
assert len(pubmedqa_dataset_with_cache) > 0

cache_dir = pubmedqa_dataset_with_cache._root
if cache_dir:
assert os.path.exists(cache_dir)
assert any(os.scandir(cache_dir))


@pytest.mark.parametrize("split", [None])
def test_prepare_data_without_download(tmp_path, split) -> None:
"""Tests dataset preparation when download is disabled and cache is missing."""
root = tmp_path / "pubmed_qa_cache"
dataset = datasets.PubMedQA(root=str(root), split=split, download=False)

with pytest.raises(RuntimeError, match="Dataset not found locally and downloading is disabled"):
dataset.prepare_data()


def test_cleanup_cache(tmp_path) -> None:
"""Tests that the cache can be cleaned up."""
root = tmp_path / "pubmed_qa_cache"
dataset = datasets.PubMedQA(root=str(root), download=True)
dataset.prepare_data()

assert os.path.exists(root)

shutil.rmtree(root)
assert not os.path.exists(root)


@pytest.fixture(scope="function")
def pubmedqa_dataset(split: None) -> datasets.PubMedQA:
"""PubMedQA dataset fixture."""
dataset = datasets.PubMedQA(split=split)
dataset.prepare_data()
return dataset


@pytest.fixture(scope="function")
def pubmedqa_dataset_with_cache(tmp_path, split: None) -> datasets.PubMedQA:
"""PubMedQA dataset fixture with caching enabled."""
root = tmp_path / "pubmed_qa_cache"
dataset = datasets.PubMedQA(root=str(root), split=split, download=True)
dataset.prepare_data()
return dataset
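For reference, these tests can be run in isolation (a sketch; the first run downloads PubMedQA from HuggingFace):

```python
# Sketch: run only the PubMedQA tests programmatically.
import pytest

pytest.main(["tests/eva/language/data/datasets/classification/test_pubmedqa.py", "-v"])
```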