Add PubMedQA dataset #740

Merged · 26 commits · Jan 21, 2025
Changes from 14 commits
205 changes: 190 additions & 15 deletions pdm.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -83,6 +83,7 @@ all = [
     "imagesize>=1.4.1",
     "scipy>=1.14.0",
     "monai>=1.3.2",
+    "datasets>=3.2.0",
 ]

[project.scripts]
13 changes: 13 additions & 0 deletions src/eva/language/__init__.py
@@ -0,0 +1,13 @@
"""eva language API."""

try:
from eva.language.data import datasets
except ImportError as e:
msg = (
"eva language requirements are not installed.\n\n"
"Please pip install as follows:\n"
' python -m pip install "kaiko-eva[language]" --upgrade'
)
raise ImportError(str(e) + "\n\n" + msg) from e

__all__ = ["datasets"]
5 changes: 5 additions & 0 deletions src/eva/language/data/__init__.py
@@ -0,0 +1,5 @@
"""Language data API."""

from eva.language.data import datasets

__all__ = ["datasets"]
9 changes: 9 additions & 0 deletions src/eva/language/data/datasets/__init__.py
@@ -0,0 +1,9 @@
"""Language Datasets API."""

from eva.language.data.datasets.classification import PubMedQA
from eva.language.data.datasets.language import LanguageDataset

__all__ = [
"PubMedQA",
"LanguageDataset",
]
7 changes: 7 additions & 0 deletions src/eva/language/data/datasets/classification/__init__.py
@@ -0,0 +1,7 @@
"""Text classification datasets API."""

from eva.language.data.datasets.classification.pubmedqa import PubMedQA

__all__ = [
"PubMedQA",
]
66 changes: 66 additions & 0 deletions src/eva/language/data/datasets/classification/base.py
@@ -0,0 +1,66 @@
"""Base for text classification datasets."""

import abc
from typing import Any, Dict, List, Tuple

import torch
from typing_extensions import override

from eva.language.data.datasets.language import LanguageDataset


class TextClassification(LanguageDataset[Tuple[str, torch.Tensor]], abc.ABC):
"""Text classification abstract dataset."""

def __init__(self) -> None:
"""Initializes the text classification dataset."""
super().__init__()

@property
def classes(self) -> List[str] | None:
"""Returns list of class names."""
return None
kurbanrita marked this conversation as resolved.
Show resolved Hide resolved

@property
def class_to_idx(self) -> Dict[str, int] | None:
"""Returns class name to index mapping."""
return None

def load_metadata(self, index: int) -> Dict[str, Any] | None:
"""Returns the dataset metadata.

Args:
index: The index of the data sample.

Returns:
The sample metadata.
"""
return None

@abc.abstractmethod
def load_text(self, index: int) -> str:
"""Returns the text content.

Args:
index: The index of the data sample.

Returns:
The text content.
"""
raise NotImplementedError

@abc.abstractmethod
def load_target(self, index: int) -> torch.Tensor:
"""Returns the target label.

Args:
index: The index of the data sample.

Returns:
The target label.
"""
raise NotImplementedError

@override
def __getitem__(self, index: int) -> Tuple[str, torch.Tensor, Dict[str, Any]]:
return (self.load_text(index), self.load_target(index), self.load_metadata(index) or {})
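
For orientation, here is a minimal sketch of what a concrete subclass looks like. `ToyTextClassification` and its in-memory samples are hypothetical, not part of this PR, and the sketch assumes `MapDataset` requires nothing beyond `__len__` in addition to the hooks shown:

from typing import Dict, List

import torch

from eva.language.data.datasets.classification.base import TextClassification


class ToyTextClassification(TextClassification):
    """Hypothetical in-memory dataset illustrating the abstract interface."""

    # Illustrative samples only: (text, label index) pairs.
    _samples = [("The treatment reduced symptoms.", 1), ("No effect was observed.", 0)]

    @property
    def classes(self) -> List[str]:
        return ["no", "yes"]

    @property
    def class_to_idx(self) -> Dict[str, int]:
        return {"no": 0, "yes": 1}

    def load_text(self, index: int) -> str:
        return self._samples[index][0]

    def load_target(self, index: int) -> torch.Tensor:
        return torch.tensor(self._samples[index][1], dtype=torch.long)

    def __len__(self) -> int:
        return len(self._samples)

Indexing such a dataset yields the three-tuple assembled by the base `__getitem__`: the text, the target tensor, and (since `load_metadata` is not overridden here) an empty metadata dict.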
126 changes: 126 additions & 0 deletions src/eva/language/data/datasets/classification/pubmedqa.py
@@ -0,0 +1,126 @@
"""PubMedQA dataset class."""

import os
from typing import Any, Dict, List

import torch
from datasets import Dataset, load_dataset
from typing_extensions import override

from eva.language.data.datasets.classification import base


class PubMedQA(base.TextClassification):
"""Dataset class for PubMedQA question answering task."""

_license: str = "MIT License (https://github.com/pubmedqa/pubmedqa/blob/master/LICENSE)"
"""Dataset license."""

def __init__(
self,
root: str | None = None,
split: str | None = "train+test+validation",
kurbanrita marked this conversation as resolved.
Show resolved Hide resolved
download: bool = False,
) -> None:
"""Initialize the PubMedQA dataset.

Args:
root: Directory to cache the dataset. If None, no local caching is used.
split: Dataset split to use. Default is "train+test+validation".
download: Whether to download the dataset if not found locally. Default is False.
"""
super().__init__()
self._root = root
kurbanrita marked this conversation as resolved.
Show resolved Hide resolved
self._split = split
self._download = download

@override
def prepare_data(self) -> None:
"""Downloads and prepares the PubMedQA dataset.

If `self._root` is None, the dataset is used directly from HuggingFace.
Otherwise, it checks if the dataset is already cached in `self._root`.
If not cached, it downloads the dataset into `self._root`.
"""
dataset_cache_path = None

if self._root:
dataset_cache_path = os.path.join(self._root, "pubmed_qa")
os.makedirs(self._root, exist_ok=True)

try:
kurbanrita marked this conversation as resolved.
Show resolved Hide resolved
if dataset_cache_path and os.path.exists(dataset_cache_path):
raw_dataset = load_dataset(
dataset_cache_path,
name="pubmed_qa_labeled_fold0_source",
split=self._split,
streaming=False,
)
print(f"Loaded dataset from local cache: {dataset_cache_path}")
else:
if not self._download and self._root:
raise ValueError(
"Dataset not found locally and downloading is disabled. "
"Set `download=True` or provide a valid local cache."
)

raw_dataset = load_dataset(
"bigbio/pubmed_qa",
name="pubmed_qa_labeled_fold0_source",
split=self._split,
cache_dir=self._root if self._root else None,
streaming=False,
)
if self._root:
print(f"Dataset downloaded and cached in: {self._root}")
else:
print("Using dataset directly from Hugging Face without caching.")
kurbanrita marked this conversation as resolved.
Show resolved Hide resolved

if not isinstance(raw_dataset, Dataset):
raise TypeError(f"Expected a `Dataset`, but got {type(raw_dataset)}")

self.dataset: Dataset = raw_dataset

except Exception as e:
raise RuntimeError(f"Failed to prepare dataset: {e}") from e
kurbanrita marked this conversation as resolved.
Show resolved Hide resolved

@property
@override
def classes(self) -> List[str]:
return ["no", "yes", "maybe"]

@property
@override
def class_to_idx(self) -> Dict[str, int]:
return {"no": 0, "yes": 1, "maybe": 2}

@override
def load_text(self, index: int) -> str:
sample = dict(self.dataset[index])
return f"Question: {sample['QUESTION']}\nContext: {sample['CONTEXTS']}"

@override
def load_target(self, index: int) -> torch.Tensor:
return torch.tensor(
self.class_to_idx[self.dataset[index]["final_decision"]], dtype=torch.long
)

@override
def load_metadata(self, index: int) -> Dict[str, Any]:
sample = self.dataset[index]
return {
"year": sample["YEAR"],
"labels": sample["LABELS"],
"meshes": sample["MESHES"],
"long_answer": sample["LONG_ANSWER"],
"reasoning_required": sample["reasoning_required_pred"],
"reasoning_free": sample["reasoning_free_pred"],
}

@override
def __len__(self) -> int:
return len(self.dataset)

def _print_license(self) -> None:
"""Prints the dataset license."""
print(f"Dataset license: {self._license}")
15 changes: 15 additions & 0 deletions src/eva/language/data/datasets/language.py
@@ -0,0 +1,15 @@
"""Language Dataset base class."""

import abc
from typing import Generic, TypeVar

from eva.core.data.datasets import base

DataSample = TypeVar("DataSample")
"""The data sample type."""


class LanguageDataset(base.MapDataset, abc.ABC, Generic[DataSample]):
    """Base dataset class for text tasks."""

    pass
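
The `DataSample` type variable lets each subclass declare what a sample looks like; `TextClassification` above binds it to `Tuple[str, torch.Tensor]`. A hypothetical string-only dataset (illustrative, not part of this PR) would parameterize it like this:

from eva.language.data.datasets.language import LanguageDataset


class RawTextDataset(LanguageDataset[str]):
    """Hypothetical dataset yielding plain strings."""

    def __init__(self, texts: list[str]) -> None:
        super().__init__()
        self._texts = texts

    def __getitem__(self, index: int) -> str:
        return self._texts[index]

    def __len__(self) -> int:
        return len(self._texts)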
1 change: 1 addition & 0 deletions tests/eva/text/__init__.py
@@ -0,0 +1 @@
"""EVA language tests."""
1 change: 1 addition & 0 deletions tests/eva/text/data/__init__.py
@@ -0,0 +1 @@
"""Language data tests."""
1 change: 1 addition & 0 deletions tests/eva/text/data/datasets/__init__.py
@@ -0,0 +1 @@
"""Language datasets tests."""
1 change: 1 addition & 0 deletions tests/eva/text/data/datasets/classification/__init__.py
@@ -0,0 +1 @@
"""Tests for the text classification datasets."""
128 changes: 128 additions & 0 deletions tests/eva/text/data/datasets/classification/test_pubmedqa.py
@@ -0,0 +1,128 @@
"""PubMedQA dataset tests."""

import os
import shutil
from typing import Literal

import pytest
import torch
from datasets import Dataset

from eva.language.data import datasets


@pytest.fixture(scope="function")
def pubmedqa_dataset(
split: Literal["train", "test", "validation", "train+test+validation"]
) -> datasets.PubMedQA:
"""PubMedQA dataset fixture."""
dataset = datasets.PubMedQA(split=split)
dataset.prepare_data()
return dataset


@pytest.fixture(scope="function")
def pubmedqa_dataset_with_cache(
tmp_path, split: Literal["train", "test", "validation", "train+test+validation"]
) -> datasets.PubMedQA:
"""PubMedQA dataset fixture with caching enabled."""
root = tmp_path / "pubmed_qa_cache"
dataset = datasets.PubMedQA(root=str(root), split=split, download=True)
dataset.prepare_data()
return dataset


@pytest.mark.parametrize(
"split, expected_length",
[("train", 450), ("test", 500), ("validation", 50), ("train+test+validation", 1000)],
)
def test_length(pubmedqa_dataset: datasets.PubMedQA, expected_length: int) -> None:
"""Tests the length of the dataset."""
assert len(pubmedqa_dataset) == expected_length


@pytest.mark.parametrize(
"split, index",
[
("train", 0),
("train", 10),
("test", 0),
("validation", 0),
("train+test+validation", 0),
],
)
def test_sample(pubmedqa_dataset: datasets.PubMedQA, index: int) -> None:
"""Tests the format of a dataset sample."""
sample = pubmedqa_dataset[index]
assert isinstance(sample, tuple)
assert len(sample) == 3

text, target, metadata = sample
assert isinstance(text, str)
assert text.startswith("Question: ")
assert "Context: " in text

assert isinstance(target, torch.Tensor)
assert target in [0, 1, 2]

assert isinstance(metadata, dict)
required_keys = {
"year",
"labels",
"meshes",
"long_answer",
"reasoning_required",
"reasoning_free",
}
assert all(key in metadata for key in required_keys)


@pytest.mark.parametrize("split", ["train", "test", "validation", "train+test+validation"])
def test_classes(pubmedqa_dataset: datasets.PubMedQA) -> None:
"""Tests the dataset classes."""
assert pubmedqa_dataset.classes == ["no", "yes", "maybe"]
assert pubmedqa_dataset.class_to_idx == {"no": 0, "yes": 1, "maybe": 2}


@pytest.mark.parametrize("split", ["train+test+validation"])
def test_prepare_data_no_root(pubmedqa_dataset: datasets.PubMedQA) -> None:
"""Tests dataset preparation without specifying a root directory."""
assert isinstance(pubmedqa_dataset.dataset, Dataset)
assert len(pubmedqa_dataset) > 0


@pytest.mark.parametrize("split", ["train+test+validation"])
def test_prepare_data_with_cache(pubmedqa_dataset_with_cache: datasets.PubMedQA) -> None:
"""Tests dataset preparation with caching."""
pubmedqa_dataset_with_cache.prepare_data()
assert isinstance(pubmedqa_dataset_with_cache.dataset, Dataset)
assert len(pubmedqa_dataset_with_cache) > 0

cache_dir = pubmedqa_dataset_with_cache._root
if cache_dir:
assert os.path.exists(cache_dir)
assert any(os.scandir(cache_dir))


@pytest.mark.parametrize("split", ["train+test+validation"])
def test_prepare_data_without_download(
tmp_path, split: Literal["train", "test", "validation", "train+test+validation"]
) -> None:
"""Tests dataset preparation when download is disabled and cache is missing."""
root = tmp_path / "pubmed_qa_cache"
dataset = datasets.PubMedQA(root=str(root), split=split, download=False)

with pytest.raises(RuntimeError, match="Dataset not found locally and downloading is disabled"):
dataset.prepare_data()


def test_cleanup_cache(tmp_path) -> None:
"""Tests that the cache can be cleaned up."""
root = tmp_path / "pubmed_qa_cache"
dataset = datasets.PubMedQA(root=str(root), split="train+test+validation", download=True)
dataset.prepare_data()

assert os.path.exists(root)

shutil.rmtree(root)
assert not os.path.exists(root)