Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] enable base_dir to be a list #392

Merged
merged 3 commits into from
Aug 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 34 additions & 22 deletions src/alpaca_eval/annotators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
import pandas as pd

from .. import completion_parsers, constants, processors, utils
from .. import completion_parsers, constants, processors, types, utils
from ..decoders import get_fn_completions

CURRENT_DIR = Path(__file__).parent
Expand All @@ -23,11 +23,11 @@ class BaseAnnotator(abc.ABC):

Parameters
----------
annotators_config : Path or list of dict, optional
A dictionary or path to a yaml file containing the configuration for the pool of annotators. If a directory,
we search for 'configs.yaml' in it. The keys in the first dictionary should be the annotator's name, and
the value should be a dictionary of the annotator's configuration which should have the following keys:
The path is relative to `base_dir` directory.
annotators_config : Path, optional
A path to a yaml file containing the configuration for the pool of annotators. The path can be absolute or
relative to `base_dir` directory. If a directory, we search for 'configs.yaml' in it. After loading, the keys
in the first dictionary should be the annotator's name, and the value should be a dictionary of the annotator's
configuration which should have the following keys:
- prompt_template (str): a prompt template or path to it. The template should contain placeholders for keys in
the example dictionary, typically {instruction} and {output_1} {output_2}.
- fn_completions (str): function in `alpaca_eval.decoders` for completions. Needs to accept as first argument
Expand Down Expand Up @@ -58,9 +58,10 @@ class BaseAnnotator(abc.ABC):
is_store_missing_annotations : bool, optional
Whether to store missing annotations. If True it avoids trying to reannotate examples that have errors.

base_dir : Path, optional
base_dir : Path or list of Path, optional
Path to the directory containing the annotators configs. I.e. annotators_config will be relative
to this directory. If None uses self.DEFAULT_BASE_DIR
to this directory. If None, uses self.DEFAULT_BASE_DIR. If a list, we use the first directory from which
annotators_config can be loaded.

is_raise_if_missing_primary_keys : bool, optional
Whether to ensure that the primary keys are in the example dictionary. If True, raises an error.
Expand All @@ -85,7 +86,7 @@ class BaseAnnotator(abc.ABC):
def __init__(
self,
primary_keys: Sequence[str],
annotators_config: Union[utils.AnyPath, list[dict[str, Any]]] = constants.DEFAULT_ANNOTATOR_CONFIG,
annotators_config: Union[types.AnyPath] = constants.DEFAULT_ANNOTATOR_CONFIG,
seed: Optional[int] = 0,
is_avoid_reannotations: bool = True,
other_output_keys_to_keep: Sequence[str] = (
Expand All @@ -95,13 +96,13 @@ def __init__(
),
other_input_keys_to_keep: Sequence[str] = (),
is_store_missing_annotations: bool = True,
base_dir: Optional[utils.AnyPath] = None,
base_dir: Optional[Union[types.AnyPath, Sequence[types.AnyPath]]] = None,
is_raise_if_missing_primary_keys: bool = True,
annotation_type: Optional[Type] = None,
is_reapply_parsing: bool = False,
):
logging.info(f"Creating the annotator from `{annotators_config}`.")
self.base_dir = Path(base_dir or self.DEFAULT_BASE_DIR)
base_dir = base_dir or self.DEFAULT_BASE_DIR
self.seed = seed
self.is_avoid_reannotations = is_avoid_reannotations
self.primary_keys = list(primary_keys)
Expand All @@ -113,7 +114,15 @@ def __init__(
self.annotation_type = annotation_type or self.DEFAULT_ANNOTATION_TYPE
self.is_reapply_parsing = is_reapply_parsing

self.annotators_config = self._initialize_annotators_config(annotators_config)
# loop over all the base_dirs until you find the annotators_config
if not isinstance(base_dir, (list, tuple, set)):
base_dir = [base_dir]
for d in base_dir:
self.base_dir = Path(d)
self.annotators_config = self._initialize_annotators_config(annotators_config)
if self.annotators_config.exists():
break

self.annotators = self._initialize_annotators()
self.df_annotations = None

Expand Down Expand Up @@ -151,7 +160,7 @@ def annotator_name(self) -> str:

def __call__(
self,
to_annotate: utils.AnyData,
to_annotate: types.AnyData,
chunksize: Optional[int] = 128,
**decoding_kwargs,
) -> list[dict[str, Any]]:
Expand Down Expand Up @@ -207,6 +216,9 @@ def __call__(

### Private methods ###
def _initialize_annotators_config(self, annotators_config):
if isinstance(annotators_config, (list, tuple)):
return annotators_config

# setting it relative to the config directory
annotators_config = self.base_dir / annotators_config

Expand Down Expand Up @@ -243,7 +255,7 @@ def _add_missing_primary_keys_(self, df: pd.DataFrame):
for c in missing_primary_keys:
df[c] = None

def _preprocess(self, to_annotate: utils.AnyData) -> pd.DataFrame:
def _preprocess(self, to_annotate: types.AnyData) -> pd.DataFrame:
"""Preprocess the examples to annotate. In particular takes care of filtering unnecessary examples."""

df_to_annotate = utils.convert_to_dataframe(to_annotate)
Expand Down Expand Up @@ -316,7 +328,7 @@ def _annotate(self, df_to_annotate: pd.DataFrame, **decoding_kwargs) -> pd.DataF
def _postprocess_and_store_(
self,
df_annotated: pd.DataFrame,
to_annotate: utils.AnyData,
to_annotate: types.AnyData,
) -> list[dict[str, Any]]:
"""Convert the dataframe into a list of dictionaries to be returned, and store current annotations."""

Expand Down Expand Up @@ -476,11 +488,11 @@ class BaseAnnotatorJSON(BaseAnnotator):
"""
)

def __init__(self, *args, caching_path: Optional[utils.AnyPath] = "auto", **kwargs):
def __init__(self, *args, caching_path: Optional[types.AnyPath] = "auto", **kwargs):
super().__init__(*args, **kwargs)
self.caching_path = self._initialize_cache(caching_path)

def save(self, path: Optional[utils.AnyPath] = None):
def save(self, path: Optional[types.AnyPath] = None):
"""Save all annotations to json."""

path = path or self.caching_path
Expand All @@ -492,7 +504,7 @@ def save(self, path: Optional[utils.AnyPath] = None):
self.df_annotations = self.df_annotations[~self.df_annotations[self.annotation_key].isna()]
self.df_annotations.to_json(path, orient="records", indent=2)

def load_(self, path: Optional[utils.AnyPath] = None):
def load_(self, path: Optional[types.AnyPath] = None):
"""Load all the annotations from json."""
path = path or self.caching_path
if path is not None:
Expand Down Expand Up @@ -591,15 +603,15 @@ class SingleAnnotator:

def __init__(
self,
prompt_template: utils.AnyPath,
fn_completion_parser: Optional[Union[Callable, str]] = "regex_parser",
prompt_template: types.AnyPath,
fn_completion_parser: Optional[Union[Callable, str]] = None,
completion_parser_kwargs: Optional[dict[str, Any]] = None,
fn_completions: Union[Callable, str] = "openai_completions",
completions_kwargs: Optional[dict[str, Any]] = None,
is_shuffle: bool = True,
seed: Optional[int] = 123,
batch_size: int = 1,
base_dir: utils.AnyPath = constants.EVALUATORS_CONFIG_DIR,
base_dir: types.AnyPath = constants.EVALUATORS_CONFIG_DIR,
annotation_column: str = "annotation",
is_store_raw_completions: bool = True,
processors_to_kwargs: Optional[dict[str, dict]] = None,
Expand Down Expand Up @@ -719,7 +731,7 @@ def _search_processor(self, name: Union[str, Type["processors.BaseProcessor"]])
assert issubclass(name, processors.BaseProcessor)
return name

def _get_prompt_template(self, prompt_template: utils.AnyPath):
def _get_prompt_template(self, prompt_template: types.AnyPath):
return utils.read_or_return(self.base_dir / prompt_template)

def _make_prompts(
Expand Down
Loading