From 122f96559bdf69304d018fbb0a90d989b1641d1e Mon Sep 17 00:00:00 2001 From: Milo Cress Date: Mon, 12 Feb 2024 12:21:11 -0500 Subject: [PATCH 1/4] strengthened chat formatting validation (#960) * strengthened chat formatting validation * fix types * made assert messages more descriptive * used raise instead of assert, added type checks * added list type check * type error if no string content * add test case for new validation * relaxed type constraints to interface minimum * use Mapping and Iterable * fix mapping in type aliases too * iterable -> sequence * sequence -> list * Mapping -> Dict * use mapping again * fixed another one * updated message * factored out duplicate functions * dict -> mapping * add sequence --- llmfoundry/data/finetuning/tasks.py | 70 ++++++++++++++++++++---- tests/data/test_template_tokenization.py | 4 +- 2 files changed, 62 insertions(+), 12 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 7f2a5417b4..126ed43812 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -35,9 +35,10 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: import logging import os import warnings +from collections.abc import Mapping from functools import partial from pathlib import Path -from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence, +from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union, cast) import datasets as hf_datasets @@ -55,13 +56,18 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: _ALLOWED_RESPONSE_KEYS = {'response', 'completion'} _ALLOWED_PROMPT_KEYS = {'prompt'} +_ALLOWED_MESSAGES_KEYS = {'messages'} +_ALLOWED_ROLE_KEYS = {'role'} +_ALLOWED_CONTENT_KEYS = {'content'} +_ALLOWED_ROLES = {'user', 'assistant', 'system'} +_ALLOWED_LAST_MESSAGE_ROLES = {'assistant'} DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath( os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir, '.downloaded_finetuning')) SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet'] -PromptResponseDict = Dict[str, str] -ChatFormattedDict = Dict[str, List[Dict[str, str]]] +PromptResponseDict = Mapping[str, str] +ChatFormattedDict = Mapping[str, List[Dict[str, str]]] Example = Union[PromptResponseDict, ChatFormattedDict] ExampleType = Literal['prompt_response', 'chat'] TokenizedExample = Dict[str, List[int]] @@ -79,7 +85,11 @@ def _get_example_type(example: Example) -> ExampleType: Raises: KeyError: If the example type is unknown. """ - if 'messages' in example: + if not isinstance(example, Mapping): + raise TypeError( + f'Expected example to be a Mapping, but found {type(example)}') + if any(allowed_message_key in example + for allowed_message_key in _ALLOWED_MESSAGES_KEYS): return 'chat' elif any([ pr in example @@ -102,6 +112,49 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool: return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0 +def _get_key(dictionary: Mapping[str, Any], allowed_keys: Set[str]): + if not isinstance(dictionary, Mapping): + raise TypeError( + f'Expected dictionary to be a mapping, but found {type(dictionary)}' + ) + desired_keys = allowed_keys.intersection(dictionary.keys()) + if len(desired_keys) != 1: + raise ValueError( + f'Dictionary has multiple keys in `allowed_keys`: {desired_keys}') + return list(desired_keys)[0] + + +def _validate_chat_formatted_example(example: ChatFormattedDict): + if not isinstance(example, Mapping): + raise TypeError( + f'Expected example to be a mapping, but found {type(example)}') + messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)] + if not isinstance(messages, List): + raise TypeError( + f'Expected messages to be an iterable, but found {type(messages)}') + if len(messages) <= 1: + raise ValueError('Chat example must have at least two messages') + + last_message = messages[-1] + role_key = _get_key(last_message, _ALLOWED_ROLE_KEYS) + last_role = last_message[role_key] + if last_role not in _ALLOWED_LAST_MESSAGE_ROLES: + raise ValueError(f'Invalid last message role: {last_role}') + + for message in messages: + role_key, content_key = _get_key(message, _ALLOWED_ROLE_KEYS), _get_key( + message, _ALLOWED_CONTENT_KEYS) + if len(message.keys()) != 2: + raise ValueError( + f'Expected 2 keys in message, but found {len(message.keys())}') + if message[role_key] not in _ALLOWED_ROLES: + raise ValueError(f'Invalid role: {message[role_key]}') + if not isinstance(message[content_key], str): + raise TypeError( + f'Expected content to be a string, but found {type(message[content_key])}' + ) + + def _slice_chat_formatted_example( example: ChatFormattedDict, tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]: @@ -118,18 +171,13 @@ def _slice_chat_formatted_example( ValueError: If the chat example has less than two messages or if the last message is not from the assistant. KeyError: If a message does not have a role or content. """ - messages = example['messages'] + _validate_chat_formatted_example(example) + messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)] - if len(messages) < 2: - raise ValueError( - f'chat example must have at least two messages. {messages=}') last_message = messages[-1] if last_message['role'] != 'assistant': raise ValueError( f'last message must be from assistant. {last_message=}') - for message in messages: - if 'role' not in message or 'content' not in message: - raise KeyError(f'message must have role and content. {message=}') full_conversation = tokenizer.apply_chat_template(messages, tokenize=False) prompt = tokenizer.apply_chat_template(messages[:-1], diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py index 829d1ebbc0..5491b94521 100644 --- a/tests/data/test_template_tokenization.py +++ b/tests/data/test_template_tokenization.py @@ -42,8 +42,10 @@ def test_tokenize_chat_example_malformed(): 'content': 'user message not followed by an assistant label' }] } + wrong_type = {'messages': 'this is not a list of messages'} malformed_chat_examples = [ - too_few_messages, no_content, ends_with_user_role, no_assistant_message + too_few_messages, no_content, ends_with_user_role, no_assistant_message, + wrong_type ] my_tokenizer = build_tokenizer('mosaicml/mpt-7b-8k-chat', {}) for example in malformed_chat_examples: From 78cbe08523f6fe187177b84832d0b6e16cb5c3b9 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 12 Feb 2024 11:31:42 -0800 Subject: [PATCH 2/4] Add new base images and remove fa1 images (#970) --- .github/workflows/docker.yaml | 16 +++++----------- README.md | 8 +++----- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index e62b01fa52..2ebbbd69f7 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -12,22 +12,16 @@ on: workflow_dispatch: {} jobs: docker-build: - runs-on: ubuntu-latest + runs-on: mosaic-4wide if: github.repository_owner == 'mosaicml' strategy: matrix: include: - - name: "2.1.0_cu121" - base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04 - dep_groups: "[gpu]" - - name: "2.1.0_cu121_flash2" - base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04 + - name: "2.1.2_cu121_flash2" + base_image: mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04 dep_groups: "[gpu-flash2]" - - name: "2.1.0_cu121_aws" - base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04-aws - dep_groups: "[gpu]" - - name: "2.1.0_cu121_flash2_aws" - base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04-aws + - name: "2.1.2_cu121_flash2_aws" + base_image: mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04-aws dep_groups: "[gpu-flash2]" steps: - name: Maximize Build Space on Worker diff --git a/README.md b/README.md index 8ffe222cb7..6668476fd4 100644 --- a/README.md +++ b/README.md @@ -113,11 +113,9 @@ You can select a specific commit hash such as `mosaicml/llm-foundry:1.13.1_cu117 | Docker Image | Torch Version | Cuda Version | LLM Foundry dependencies installed? | | ------------------------------------------------------ | ------------- | ----------------- | ----------------------------------- | -| `mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04` | 2.1.0 | 12.1 (Infiniband) | No | -| `mosaicml/llm-foundry:2.1.0_cu121-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v1. Warning: Support for flash attention v1 has been deprecated.) | -| `mosaicml/llm-foundry:2.1.0_cu121_flash2-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v2. Note: We recommend using flash attention v2.) | -| `mosaicml/llm-foundry:2.1.0_cu121_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v1. Warning: Support for flash attention v1 has been deprecated.) | -| `mosaicml/llm-foundry:2.1.0_cu121_flash2_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v2. Note: We recommend using flash attention v2.) | +| `mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04` | 2.1.2 | 12.1 (Infiniband) | No | +| `mosaicml/llm-foundry:2.1.2_cu121_flash2-latest` | 2.1.2 | 12.1 (Infiniband) | Yes | +| `mosaicml/llm-foundry:2.1.2_cu121_flash2_aws-latest` | 2.1.2 | 12.1 (EFA) | Yes | # Installation From e88cdf1166c02a610f1c278af334653aabf73e74 Mon Sep 17 00:00:00 2001 From: Max Marion Date: Mon, 12 Feb 2024 15:31:24 -0800 Subject: [PATCH 3/4] Add new ICL kwargs in eval.py and long_context yamls (#925) * add yamls w/ old links * load from max's public hf and parse hf datasets * update rest of tasks * add better logging * implemented leval tasks * move level * add level yaml * add str parsing to hf * wip * llm-foundry working with new parser * working w/ new parsing * fix old long context tasks * wip * wip * wip * wip * update to hf_parsing_map * rm defaults * fix parsing vars * update defaults again * rm merge conflict * fix gen_kwargs * rm old code path * fixups * wip * rm leval from pr * fix comments in yamls * add cot params * add fewshot_random_seed * fix early_stopping_criteria, fewshot_num_seed default * undo rm hf_eval * add fewshot_random_seed to test * add 64k tasks * add longer context, update composer versin * address comments * mixed * use seed by default * rm long_context_eval_8k.yaml * add longer context evals * mv yamls * eval gauntlet wip * update niah and wikiqa * fix linting * add default option * change defaults * fix linting * fix linting 2 --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/utils/builders.py | 14 + scripts/eval/eval.py | 1 + scripts/eval/local_data/EVAL_GAUNTLET.md | 42 ++ .../eval_gauntlet_long_context_length.yaml | 134 ++++++ .../eval_gauntlet_long_context_section.yaml | 130 ++++++ scripts/eval/yamls/hf_eval.yaml | 10 +- scripts/eval/yamls/long_context_tasks.yaml | 388 ++++++++++++++++++ tests/utils/test_builders.py | 1 + 8 files changed, 715 insertions(+), 5 deletions(-) create mode 100644 scripts/eval/yamls/eval_gauntlet_long_context_length.yaml create mode 100644 scripts/eval/yamls/eval_gauntlet_long_context_section.yaml create mode 100644 scripts/eval/yamls/long_context_tasks.yaml diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 644c8e0c78..fb3a0d97f8 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -53,6 +53,7 @@ def build_evaluators( device_eval_batch_size: int, icl_seq_len: int, icl_subset_num_batches: Optional[int], + fewshot_random_seed: Optional[int] = 1234, ) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]: evaluators = [] @@ -72,6 +73,7 @@ def build_evaluators( tokenizer, device_eval_batch_size, icl_seq_len, + fewshot_random_seed, icl_subset_num_batches, ) evaluators.extend(icl_evaluators) @@ -128,6 +130,7 @@ def build_icl_data_and_gauntlet( tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int, icl_seq_len: int, + fewshot_random_seed: Optional[int] = 1234, icl_subset_num_batches: Optional[int] = None ) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]: icl_evaluators, logger_keys = build_icl_evaluators( @@ -135,6 +138,7 @@ def build_icl_data_and_gauntlet( tokenizer, icl_seq_len, device_eval_batch_size, + fewshot_random_seed=fewshot_random_seed, icl_subset_num_batches=icl_subset_num_batches) eval_gauntlet_cb = None if eval_gauntlet_config is not None: @@ -442,6 +446,7 @@ def build_icl_evaluators( default_max_seq_len: int, default_batch_size: int, destination_dir: Optional[str] = None, + fewshot_random_seed: Optional[int] = 1234, icl_subset_num_batches: Optional[int] = None, ) -> Tuple[List[Evaluator], List[str]]: if destination_dir is None: @@ -516,6 +521,10 @@ def _validate_cfg(icl_cfg: DictConfig): if dist.get_local_rank() == 0 and os.path.exists(destination_path): os.remove(destination_path) dist.barrier() + + hf_parsing_map = icl_cfg.get('hf_parsing_map', {}) + hf_loading_vars = icl_cfg.get('hf_loading_vars', {}) + early_stopping_criteria = icl_cfg.get('early_stopping_criteria', None) if isinstance(early_stopping_criteria, ListConfig): @@ -533,13 +542,18 @@ def _validate_cfg(icl_cfg: DictConfig): num_fewshot=num_fewshot, prompt_string=icl_cfg.prompt_string, example_delimiter=icl_cfg.example_delimiter, + hf_loading_vars=hf_loading_vars, + hf_parsing_map=hf_parsing_map, continuation_delimiter=icl_cfg.continuation_delimiter, question_prelimiter=icl_cfg.get('question_prelimiter', ''), destination_path=destination_path, + fewshot_random_seed=icl_cfg.get('fewshot_random_seed', + fewshot_random_seed), pass_at_k=icl_cfg.pass_at_k, generations_per_sample=icl_cfg.num_beams, has_categories=icl_cfg.get('has_categories', False), cot_delimiter=icl_cfg.get('cot_delimiter', ''), + generation_kwargs=icl_cfg.get('generation_kwargs', {}), early_stopping_criteria=early_stopping_criteria, do_normalization=icl_cfg.get('do_normalization', True)) if hasattr( diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index e36e08575b..9c8dad0977 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -95,6 +95,7 @@ def evaluate_model( tokenizer=tokenizer, device_eval_batch_size=device_eval_batch_size, icl_seq_len=max_seq_len, + fewshot_random_seed=seed, icl_subset_num_batches=icl_subset_num_batches, ) diff --git a/scripts/eval/local_data/EVAL_GAUNTLET.md b/scripts/eval/local_data/EVAL_GAUNTLET.md index 8d84849b6b..ab11ea71de 100644 --- a/scripts/eval/local_data/EVAL_GAUNTLET.md +++ b/scripts/eval/local_data/EVAL_GAUNTLET.md @@ -253,3 +253,45 @@ Programming tasks evaluate the model's ability to understand code, write functio - Year released: 2022 - Number of few shot examples: 0 - Random baseline accuracy: 0% +54. HumanEval Python 25% code generation + - Description: HumanEval Python 25% is an easier variant of HumanEval Python in which in addition to the original method signature, the model is also provided 25% of the lines in the canonical solution and expected to complete the reaminder of the program. It consists of 164 samples. + - Year released: 2023 + - Number of few shot examples: 0 + - Random baseline accuracy: 0% +55. HumanEval Python 50% code generation + - Description: HumanEval Python 50% is an easier variant of HumanEval Python in which in addition to the original method signature, the model is also provided 50% of the lines in the canonical solution and expected to complete the reaminder of the program. It consists of 164 samples. + - Year released: 2023 + - Number of few shot examples: 0 + - Random baseline accuracy: 0% +56. HumanEval Python 75% code generation + - Description: HumanEval Python 75% is an easier variant of HumanEval Python in which in addition to the original method signature, the model is also provided 75% of the lines in the canonical solution and expected to complete the reaminder of the program. It consists of 164 samples. + - Year released: 2023 + - Number of few shot examples: 0 + - Random baseline accuracy: 0% +57. HumanEval Python simple return statement code generation + - Description: HumanEval Python simple return statament is an easier variant of HumanEval Python in which the model is provided all of the canonical solution with the exception of the return statement and is expected to complete the return statement. Additionally, this set contains only the problems for which the canonical solution has a "simple" return statement consisting only of a line of the form `return VARIABLE\_NAME`. There are 37 samples. + - Year released: 2023 + - Number of few shot examples: 0 + - Random baseline accuracy: 0% +58. HumanEval Python complex return statement code generation + - Description: HumanEval Pythom complex return statament is an easier variant of HumanEval Python in which the model is provided all of the canonical solution with the exception of the return statement and is expected to complete the return statement. Additionally, this set contains only the problems for which the canonical solution does not have a "simple" return statement as defined above. There are 127 samples. + - Year released: 2023 + - Number of few shot examples: 0 + - Random baseline accuracy: 0% + +### Long Context Gauntlet + +We've included three different tasks for long (> 4000 tokens) context length evals. They are meant as litmus tests for a model's ability to properly utilize it's longer context length, which is often the result of fine-tuning after pre-training. For some of these datasets, we explicitly create sets where the required information is located in different sections of the input context, either the beginning, middle, or end of the input context. + +1. HotPotQAXL + - Description: (HotPotQA)[https://hotpotqa.github.io/] is originally a dataset of ten documents and a question requiring comprehension of one or more of the supplied documents. The non-related documents are completely unrelated and called "distractor" documents. To extend this to longer context lengths, we randomly sample documents from the full set of documents across the dataset, adding them to the current datapoint until the set of documents and its question fills the current context length. We insert the "gold" document(s) (the document(s) containing the information that answers the question) within the first third, second third, or last third of the context length. + - Lengths: 2k, 4k, 8k, 16k, 32k, 64k + - Locations: beginning, middle, end +2. Key Value Pairs (Needle In a Haystack) + - Description: We construct a `.json` of key value pairs, where both the key and value are random hashes, in the style of (Lost in the Middle)[https://github.com/nelson-liu/lost-in-the-middle]. We ask the model to produce a value given a key from a specific key value pair found int he json. The pair is correspondingly located in the first third, second third, or last third of the json. + - Lengths: 2k, 4k, 8k, 16k, 32k, 64k + - Locations: beginning, middle, end +2. WikiQA Numeric + - Description: (WikiQA Numeric)[https://huggingface.co/datasets/abacusai/WikiQA-Altered_Numeric_QA] is a Wikipedia Question Answering dataset with a focus on questions with numeric answers. We preprocess the data only to easily parse it for our framework. + - Lengths: 2k, 4k, 8k, 16k + - Locations: N/A diff --git a/scripts/eval/yamls/eval_gauntlet_long_context_length.yaml b/scripts/eval/yamls/eval_gauntlet_long_context_length.yaml new file mode 100644 index 0000000000..bcb52bf658 --- /dev/null +++ b/scripts/eval/yamls/eval_gauntlet_long_context_length.yaml @@ -0,0 +1,134 @@ +eval_gauntlet: + weighting: EQUAL + subtract_random_baseline: true + rescale_accuracy: true + categories: + - name: 2k + benchmarks: + - name: hotpotqa_beginning_2k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_middle_2k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_end_2k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_beginning_2k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_middle_2k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_end_2k + num_fewshot: 0 + random_baseline: 0 + - name: wikiqa_2k + num_fewshot: 0 + random_baseline: 0 + - name: 4k + benchmarks: + - name: hotpotqa_beginning_4k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_middle_4k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_end_4k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_beginning_4k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_middle_4k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_end_4k + num_fewshot: 0 + random_baseline: 0 + - name: wikiqa_4k + num_fewshot: 0 + random_baseline: 0 + - name: 8k + benchmarks: + - name: hotpotqa_beginning_8k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_middle_8k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_end_8k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_beginning_8k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_middle_8k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_end_8k + num_fewshot: 0 + random_baseline: 0 + - name: wikiqa_8k + num_fewshot: 0 + random_baseline: 0 + - name: 16k + benchmarks: + - name: hotpotqa_beginning_16k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_middle_16k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_end_16k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_beginning_16k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_middle_16k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_end_16k + num_fewshot: 0 + random_baseline: 0 + - name: 32k + benchmarks: + - name: hotpotqa_beginning_32k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_middle_32k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_end_32k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_beginning_32k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_middle_32k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_end_32k + num_fewshot: 0 + random_baseline: 0 + - name: 64k + benchmarks: + - name: hotpotqa_beginning_64k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_middle_64k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_end_64k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_beginning_64k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_middle_64k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_end_64k + num_fewshot: 0 + random_baseline: 0 diff --git a/scripts/eval/yamls/eval_gauntlet_long_context_section.yaml b/scripts/eval/yamls/eval_gauntlet_long_context_section.yaml new file mode 100644 index 0000000000..f776047c38 --- /dev/null +++ b/scripts/eval/yamls/eval_gauntlet_long_context_section.yaml @@ -0,0 +1,130 @@ +eval_gauntlet: + weighting: EQUAL + subtract_random_baseline: true + rescale_accuracy: true + categories: + - name: beginning + benchmarks: + - name: hotpotqa_beginning_2k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_beginning_2k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_beginning_4k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_beginning_4k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_beginning_8k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_beginning_8k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_beginning_16k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_beginning_16k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_beginning_32k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_beginning_32k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_beginning_64k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_beginning_64k + num_fewshot: 0 + random_baseline: 0 + - name: middle + benchmarks: + - name: hotpotqa_middle_2k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_middle_2k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_middle_4k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_middle_4k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_middle_8k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_middle_8k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_middle_16k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_middle_16k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_middle_32k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_middle_32k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_middle_64k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_middle_64k + num_fewshot: 0 + random_baseline: 0 + - name: end + benchmarks: + - name: hotpotqa_end_2k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_end_2k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_end_4k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_end_4k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_end_8k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_end_8k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_end_16k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_end_16k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_end_32k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_end_32k + num_fewshot: 0 + random_baseline: 0 + - name: hotpotqa_end_64k + num_fewshot: 0 + random_baseline: 0 + - name: kv_pairs_end_64k + num_fewshot: 0 + random_baseline: 0 + - name: full + benchmarks: + - name: wikiqa_2k + num_fewshot: 0 + random_baseline: 0 + - name: wikiqa_4k + num_fewshot: 0 + random_baseline: 0 + - name: wikiqa_8k + num_fewshot: 0 + random_baseline: 0 diff --git a/scripts/eval/yamls/hf_eval.yaml b/scripts/eval/yamls/hf_eval.yaml index bf0426b357..9eb0245f9a 100644 --- a/scripts/eval/yamls/hf_eval.yaml +++ b/scripts/eval/yamls/hf_eval.yaml @@ -37,11 +37,11 @@ models: device_eval_batch_size: 4 # FSDP config for model sharding -# fsdp_config: -# sharding_strategy: FULL_SHARD -# mixed_precision: FULL -# forward_prefetch: True -# limit_all_gathers: True +fsdp_config: + sharding_strategy: FULL_SHARD + mixed_precision: FULL + forward_prefetch: True + limit_all_gathers: True icl_tasks: "eval/yamls/tasks_v0.3.yaml" eval_gauntlet: "eval/yamls/eval_gauntlet_v0.3.yaml" diff --git a/scripts/eval/yamls/long_context_tasks.yaml b/scripts/eval/yamls/long_context_tasks.yaml new file mode 100644 index 0000000000..daf958a340 --- /dev/null +++ b/scripts/eval/yamls/long_context_tasks.yaml @@ -0,0 +1,388 @@ +icl_tasks: +- + label: kv_pairs_beginning_2k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: kv_pairs + context_length: 2048 + section: beginning + split: test +- + label: kv_pairs_middle_2k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: kv_pairs + context_length: 2048 + section: middle + split: test +- + label: kv_pairs_end_2k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: kv_pairs + context_length: 2048 + section: end + split: test +- + label: kv_pairs_beginning_4k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: kv_pairs + context_length: 4096 + section: beginning + split: test +- + label: kv_pairs_middle_4k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: kv_pairs + context_length: 4096 + section: middle + split: test +- + label: kv_pairs_end_4k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: kv_pairs + context_length: 4096 + section: end + split: test +- + label: kv_pairs_beginning_8k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: kv_pairs + context_length: 8192 + section: beginning + split: test +- + label: kv_pairs_middle_8k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: kv_pairs + context_length: 8192 + section: middle + split: test +- + label: kv_pairs_end_8k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: kv_pairs + context_length: 8192 + section: end + split: test +- + label: wikiqa_2k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: wikiqa + context_length: 2048 + split: test +- + label: wikiqa_4k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: wikiqa + context_length: 2048 + split: test +- + label: wikiqa_8k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: wikiqa + context_length: 2048 + split: test +- + label: hotpotqa_beginning_2k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 2048 + section: beginning + split: test +- + label: hotpotqa_middle_2k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 2048 + section: middle + split: test +- + label: hotpotqa_end_2k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 2048 + section: end + split: test +- + label: hotpotqa_beginning_4k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 4096 + section: beginning + split: test +- + label: hotpotqa_middle_4k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 4096 + section: middle + split: test +- + label: hotpotqa_end_4k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 4096 + section: end + split: test +- + label: hotpotqa_beginning_8k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 8192 + section: beginning + split: test +- + label: hotpotqa_middle_8k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 8192 + section: middle + split: test +- + label: hotpotqa_end_8k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 8192 + section: end + split: test +- + label: hotpotqa_beginning_16k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 16384 + section: beginning + split: test +- + label: hotpotqa_beginning_32k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 32768 + section: beginning + split: test +- + label: hotpotqa_beginning_64k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 65536 + section: beginning + split: test +- + label: hotpotqa_middle_16k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 16384 + section: middle + split: test +- + label: hotpotqa_middle_32k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 32768 + section: middle + split: test +- + label: hotpotqa_middle_64k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 65536 + section: middle + split: test +- + label: hotpotqa_end_16k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 16384 + section: end + split: test +- + label: hotpotqa_end_32k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 32768 + section: end + split: test +- + label: hotpotqa_end_64k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 65536 + section: end + split: test +- + label: kv_pairs_beginning_16k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 16384 + section: beginning + split: test +- + label: kv_pairs_beginning_32k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 32768 + section: beginning + split: test +- + label: kv_pairs_beginning_64k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 65536 + section: beginning + split: test +- + label: kv_pairs_middle_16k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 16384 + section: middle + split: test +- + label: kv_pairs_middle_32k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 32768 + section: middle + split: test +- + label: kv_pairs_middle_64k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 65536 + section: middle + split: test +- + label: kv_pairs_end_16k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 16384 + section: end + split: test +- + label: kv_pairs_end_32k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 32768 + section: end + split: test +- + label: kv_pairs_end_64k + dataset_uri: hf://mosaicml/long_context_eval + num_fewshot: [0] + icl_task_type: question_answering + hf_loading_vars: + name: hotpotqa + context_length: 65536 + section: end + split: test diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index 08c3504491..81d8a841c7 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -250,6 +250,7 @@ def test_build_evaluators_empty(): None, tokenizer=None, # type: ignore device_eval_batch_size=1, + fewshot_random_seed=1234, icl_seq_len=2, icl_subset_num_batches=3) assert evaluators == [] From bce53741b56e977c37d937cac582a2c03481174c Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 12 Feb 2024 17:12:09 -0800 Subject: [PATCH 4/4] Make Composer pins consistent with each other (#972) --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index c88f566f26..8fdfc2197b 100644 --- a/setup.py +++ b/setup.py @@ -88,14 +88,14 @@ ] extra_deps['databricks'] = [ - 'mosaicml[databricks]>=0.19,<0.20', + 'mosaicml[databricks]>=0.19.1,<0.20', 'databricks-sql-connector>=3,<4', 'databricks-connect==14.1.0', 'lz4>=4,<5', ] extra_deps['tensorboard'] = [ - 'mosaicml[tensorboard]>=0.19,<0.20', + 'mosaicml[tensorboard]>=0.19.1,<0.20', ] extra_deps['gpu'] = [ @@ -110,7 +110,7 @@ ] extra_deps['peft'] = [ - 'mosaicml[peft]>=0.19,<0.20', + 'mosaicml[peft]>=0.19.1,<0.20', ] extra_deps['openai'] = [