Merge branch 'main' into batch_code_eval
dakinggg authored Feb 14, 2024
2 parents f6e3f1f + bce5374 commit 7541f89
Showing 13 changed files with 788 additions and 36 deletions.
16 changes: 5 additions & 11 deletions .github/workflows/docker.yaml
@@ -12,22 +12,16 @@ on:
workflow_dispatch: {}
jobs:
docker-build:
runs-on: ubuntu-latest
runs-on: mosaic-4wide
if: github.repository_owner == 'mosaicml'
strategy:
matrix:
include:
- name: "2.1.0_cu121"
base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
dep_groups: "[gpu]"
- name: "2.1.0_cu121_flash2"
base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
- name: "2.1.2_cu121_flash2"
base_image: mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04
dep_groups: "[gpu-flash2]"
- name: "2.1.0_cu121_aws"
base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04-aws
dep_groups: "[gpu]"
- name: "2.1.0_cu121_flash2_aws"
base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04-aws
- name: "2.1.2_cu121_flash2_aws"
base_image: mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04-aws
dep_groups: "[gpu-flash2]"
steps:
- name: Maximize Build Space on Worker
8 changes: 3 additions & 5 deletions README.md
@@ -113,11 +113,9 @@ You can select a specific commit hash such as `mosaicml/llm-foundry:1.13.1_cu117

| Docker Image | Torch Version | Cuda Version | LLM Foundry dependencies installed? |
| ------------------------------------------------------ | ------------- | ----------------- | ----------------------------------- |
| `mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04` | 2.1.0 | 12.1 (Infiniband) | No |
| `mosaicml/llm-foundry:2.1.0_cu121-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v1. Warning: Support for flash attention v1 has been deprecated.) |
| `mosaicml/llm-foundry:2.1.0_cu121_flash2-latest` | 2.1.0 | 12.1 (Infiniband) | Yes (flash attention v2. Note: We recommend using flash attention v2.) |
| `mosaicml/llm-foundry:2.1.0_cu121_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v1. Warning: Support for flash attention v1 has been deprecated.) |
| `mosaicml/llm-foundry:2.1.0_cu121_flash2_aws-latest` | 2.1.0 | 12.1 (EFA) | Yes (flash attention v2. Note: We recommend using flash attention v2.) |
| `mosaicml/pytorch:2.1.2_cu121-python3.10-ubuntu20.04` | 2.1.2 | 12.1 (Infiniband) | No |
| `mosaicml/llm-foundry:2.1.2_cu121_flash2-latest` | 2.1.2 | 12.1 (Infiniband) | Yes |
| `mosaicml/llm-foundry:2.1.2_cu121_flash2_aws-latest` | 2.1.2 | 12.1 (EFA) | Yes |


# Installation
70 changes: 59 additions & 11 deletions llmfoundry/data/finetuning/tasks.py
@@ -35,9 +35,10 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
import logging
import os
import warnings
from collections.abc import Mapping
from functools import partial
from pathlib import Path
from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence,
from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence, Set,
Tuple, Union, cast)

import datasets as hf_datasets
@@ -55,13 +56,18 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:

_ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
_ALLOWED_PROMPT_KEYS = {'prompt'}
_ALLOWED_MESSAGES_KEYS = {'messages'}
_ALLOWED_ROLE_KEYS = {'role'}
_ALLOWED_CONTENT_KEYS = {'content'}
_ALLOWED_ROLES = {'user', 'assistant', 'system'}
_ALLOWED_LAST_MESSAGE_ROLES = {'assistant'}
DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath(
os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir,
'.downloaded_finetuning'))
SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet']

PromptResponseDict = Dict[str, str]
ChatFormattedDict = Dict[str, List[Dict[str, str]]]
PromptResponseDict = Mapping[str, str]
ChatFormattedDict = Mapping[str, List[Dict[str, str]]]
Example = Union[PromptResponseDict, ChatFormattedDict]
ExampleType = Literal['prompt_response', 'chat']
TokenizedExample = Dict[str, List[int]]
@@ -79,7 +85,11 @@ def _get_example_type(example: Example) -> ExampleType:
Raises:
KeyError: If the example type is unknown.
"""
if 'messages' in example:
if not isinstance(example, Mapping):
raise TypeError(
f'Expected example to be a Mapping, but found {type(example)}')
if any(allowed_message_key in example
for allowed_message_key in _ALLOWED_MESSAGES_KEYS):
return 'chat'
elif any([
pr in example
@@ -102,6 +112,49 @@ def _is_empty_or_nonexistent(dirpath: str) -> bool:
return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0


def _get_key(dictionary: Mapping[str, Any], allowed_keys: Set[str]):
if not isinstance(dictionary, Mapping):
raise TypeError(
f'Expected dictionary to be a mapping, but found {type(dictionary)}'
)
desired_keys = allowed_keys.intersection(dictionary.keys())
if len(desired_keys) != 1:
raise ValueError(
f'Dictionary has multiple keys in `allowed_keys`: {desired_keys}')
return list(desired_keys)[0]


def _validate_chat_formatted_example(example: ChatFormattedDict):
if not isinstance(example, Mapping):
raise TypeError(
f'Expected example to be a mapping, but found {type(example)}')
messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)]
if not isinstance(messages, List):
raise TypeError(
f'Expected messages to be an iterable, but found {type(messages)}')
if len(messages) <= 1:
raise ValueError('Chat example must have at least two messages')

last_message = messages[-1]
role_key = _get_key(last_message, _ALLOWED_ROLE_KEYS)
last_role = last_message[role_key]
if last_role not in _ALLOWED_LAST_MESSAGE_ROLES:
raise ValueError(f'Invalid last message role: {last_role}')

for message in messages:
role_key, content_key = _get_key(message, _ALLOWED_ROLE_KEYS), _get_key(
message, _ALLOWED_CONTENT_KEYS)
if len(message.keys()) != 2:
raise ValueError(
f'Expected 2 keys in message, but found {len(message.keys())}')
if message[role_key] not in _ALLOWED_ROLES:
raise ValueError(f'Invalid role: {message[role_key]}')
if not isinstance(message[content_key], str):
raise TypeError(
f'Expected content to be a string, but found {type(message[content_key])}'
)


def _slice_chat_formatted_example(
example: ChatFormattedDict,
tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]:
@@ -118,18 +171,13 @@ def _slice_chat_formatted_example(
ValueError: If the chat example has less than two messages or if the last message is not from the assistant.
KeyError: If a message does not have a role or content.
"""
messages = example['messages']
_validate_chat_formatted_example(example)
messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)]

if len(messages) < 2:
raise ValueError(
f'chat example must have at least two messages. {messages=}')
last_message = messages[-1]
if last_message['role'] != 'assistant':
raise ValueError(
f'last message must be from assistant. {last_message=}')
for message in messages:
if 'role' not in message or 'content' not in message:
raise KeyError(f'message must have role and content. {message=}')

full_conversation = tokenizer.apply_chat_template(messages, tokenize=False)
prompt = tokenizer.apply_chat_template(messages[:-1],
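For reference, here is a minimal standalone sketch of the structural rules the new `_validate_chat_formatted_example` helper enforces on chat-formatted examples. It restates the checks for illustration only (the `check_chat_example` name is made up here) and is not the library implementation.

```python
# Illustrative restatement of the checks above; not the library code.
from typing import Any, List, Mapping

ALLOWED_ROLES = {'user', 'assistant', 'system'}

def check_chat_example(example: Mapping[str, Any]) -> None:
    messages: List[Mapping[str, str]] = example['messages']
    if len(messages) <= 1:
        raise ValueError('Chat example must have at least two messages')
    if messages[-1]['role'] != 'assistant':
        raise ValueError('Last message must come from the assistant')
    for message in messages:
        if set(message.keys()) != {'role', 'content'}:
            raise ValueError('Each message needs exactly a role key and a content key')
        if message['role'] not in ALLOWED_ROLES:
            raise ValueError(f"Invalid role: {message['role']}")
        if not isinstance(message['content'], str):
            raise TypeError('Message content must be a string')

# Accepted: at least two messages, ending with the assistant.
check_chat_example({
    'messages': [
        {'role': 'user', 'content': 'What is 2 + 2?'},
        {'role': 'assistant', 'content': '4'},
    ]
})

# Rejected: the conversation does not end with an assistant message.
# check_chat_example({'messages': [{'role': 'assistant', 'content': 'Hi'},
#                                  {'role': 'user', 'content': 'Hello'}]})
```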
14 changes: 14 additions & 0 deletions llmfoundry/utils/builders.py
@@ -53,6 +53,7 @@ def build_evaluators(
device_eval_batch_size: int,
icl_seq_len: int,
icl_subset_num_batches: Optional[int],
fewshot_random_seed: Optional[int] = 1234,
) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:

evaluators = []
@@ -72,6 +73,7 @@
tokenizer,
device_eval_batch_size,
icl_seq_len,
fewshot_random_seed,
icl_subset_num_batches,
)
evaluators.extend(icl_evaluators)
@@ -128,13 +130,15 @@ def build_icl_data_and_gauntlet(
tokenizer: PreTrainedTokenizerBase,
device_eval_batch_size: int,
icl_seq_len: int,
fewshot_random_seed: Optional[int] = 1234,
icl_subset_num_batches: Optional[int] = None
) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:
icl_evaluators, logger_keys = build_icl_evaluators(
icl_tasks_config,
tokenizer,
icl_seq_len,
device_eval_batch_size,
fewshot_random_seed=fewshot_random_seed,
icl_subset_num_batches=icl_subset_num_batches)
eval_gauntlet_cb = None
if eval_gauntlet_config is not None:
@@ -442,6 +446,7 @@ def build_icl_evaluators(
default_max_seq_len: int,
default_batch_size: int,
destination_dir: Optional[str] = None,
fewshot_random_seed: Optional[int] = 1234,
icl_subset_num_batches: Optional[int] = None,
) -> Tuple[List[Evaluator], List[str]]:
if destination_dir is None:
@@ -516,6 +521,10 @@ def _validate_cfg(icl_cfg: DictConfig):
if dist.get_local_rank() == 0 and os.path.exists(destination_path):
os.remove(destination_path)
dist.barrier()

hf_parsing_map = icl_cfg.get('hf_parsing_map', {})
hf_loading_vars = icl_cfg.get('hf_loading_vars', {})

early_stopping_criteria = icl_cfg.get('early_stopping_criteria',
None)
if isinstance(early_stopping_criteria, ListConfig):
@@ -533,13 +542,18 @@ def _validate_cfg(icl_cfg: DictConfig):
num_fewshot=num_fewshot,
prompt_string=icl_cfg.prompt_string,
example_delimiter=icl_cfg.example_delimiter,
hf_loading_vars=hf_loading_vars,
hf_parsing_map=hf_parsing_map,
continuation_delimiter=icl_cfg.continuation_delimiter,
question_prelimiter=icl_cfg.get('question_prelimiter', ''),
destination_path=destination_path,
fewshot_random_seed=icl_cfg.get('fewshot_random_seed',
fewshot_random_seed),
pass_at_k=icl_cfg.pass_at_k,
generations_per_sample=icl_cfg.generations_per_sample,
has_categories=icl_cfg.get('has_categories', False),
cot_delimiter=icl_cfg.get('cot_delimiter', ''),
generation_kwargs=icl_cfg.get('generation_kwargs', {}),
early_stopping_criteria=early_stopping_criteria,
do_normalization=icl_cfg.get('do_normalization', True),
generation_kwargs=icl_cfg.get('generation_kwargs'))
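To show where the new config plumbing lands, below is a hypothetical ICL task entry of the kind consumed by `build_icl_evaluators`. Only `fewshot_random_seed`, `hf_loading_vars`, `hf_parsing_map`, and `generation_kwargs` come from this diff; every other key name and value is an assumed placeholder rather than a verified config.

```python
# Hypothetical ICL task entry; only fewshot_random_seed, hf_loading_vars,
# hf_parsing_map, and generation_kwargs are taken from this diff. All other
# keys and all values are assumed placeholders for illustration.
from omegaconf import OmegaConf

icl_task_cfg = OmegaConf.create({
    'label': 'example_long_context_task',           # assumed task name
    'dataset_uri': 'hf://example-org/example-qa',   # assumed dataset URI
    'icl_task_type': 'question_answering',          # assumed task type
    'num_fewshot': [0],
    'fewshot_random_seed': 1234,
    'hf_loading_vars': {'split': 'test'},
    'hf_parsing_map': {'context': ['context'], 'answer': ['answer']},
    'generation_kwargs': {'temperature': 0.0},
})
```

If `fewshot_random_seed` is omitted from the task entry, the diff falls back to the `fewshot_random_seed` argument of `build_icl_evaluators`, which `scripts/eval/eval.py` now sets to the run seed.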
1 change: 1 addition & 0 deletions scripts/eval/eval.py
@@ -95,6 +95,7 @@ def evaluate_model(
tokenizer=tokenizer,
device_eval_batch_size=device_eval_batch_size,
icl_seq_len=max_seq_len,
fewshot_random_seed=seed,
icl_subset_num_batches=icl_subset_num_batches,
)

42 changes: 42 additions & 0 deletions scripts/eval/local_data/EVAL_GAUNTLET.md
@@ -253,3 +253,45 @@ Programming tasks evaluate the model's ability to understand code, write functio
- Year released: 2022
- Number of few shot examples: 0
- Random baseline accuracy: 0%
54. HumanEval Python 25% code generation
    - Description: HumanEval Python 25% is an easier variant of HumanEval Python in which, in addition to the original method signature, the model is also provided 25% of the lines in the canonical solution and is expected to complete the remainder of the program. It consists of 164 samples. (A sketch of this style of prompt construction appears after these HumanEval variants.)
    - Year released: 2023
    - Number of few shot examples: 0
    - Random baseline accuracy: 0%
55. HumanEval Python 50% code generation
    - Description: HumanEval Python 50% is an easier variant of HumanEval Python in which, in addition to the original method signature, the model is also provided 50% of the lines in the canonical solution and is expected to complete the remainder of the program. It consists of 164 samples.
    - Year released: 2023
    - Number of few shot examples: 0
    - Random baseline accuracy: 0%
56. HumanEval Python 75% code generation
    - Description: HumanEval Python 75% is an easier variant of HumanEval Python in which, in addition to the original method signature, the model is also provided 75% of the lines in the canonical solution and is expected to complete the remainder of the program. It consists of 164 samples.
    - Year released: 2023
    - Number of few shot examples: 0
    - Random baseline accuracy: 0%
57. HumanEval Python simple return statement code generation
    - Description: HumanEval Python simple return statement is an easier variant of HumanEval Python in which the model is provided all of the canonical solution except the return statement and is expected to complete the return statement. Additionally, this set contains only the problems for which the canonical solution has a "simple" return statement consisting only of a line of the form `return VARIABLE\_NAME`. There are 37 samples.
    - Year released: 2023
    - Number of few shot examples: 0
    - Random baseline accuracy: 0%
58. HumanEval Python complex return statement code generation
    - Description: HumanEval Python complex return statement is an easier variant of HumanEval Python in which the model is provided all of the canonical solution except the return statement and is expected to complete the return statement. Additionally, this set contains only the problems for which the canonical solution does not have a "simple" return statement as defined above. There are 127 samples.
    - Year released: 2023
    - Number of few shot examples: 0
    - Random baseline accuracy: 0%
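As a rough illustration of the 25/50/75% variants above, a partial-solution prompt could be assembled as follows. This is a sketch under the assumption that the provided lines are the leading fraction of the canonical solution; it is not the dataset's actual preprocessing code, and `make_partial_prompt` is a made-up name.

```python
# Sketch only: assumes the "provided" lines are the leading fraction of the
# canonical solution, appended to the original signature and docstring.
def make_partial_prompt(signature_and_docstring: str,
                        canonical_solution: str,
                        fraction: float = 0.25) -> str:
    lines = canonical_solution.splitlines(keepends=True)
    provided = lines[:int(len(lines) * fraction)]
    return signature_and_docstring + ''.join(provided)
```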

### Long Context Gauntlet

We've included three different tasks for long (> 4000 tokens) context length evals. They are meant as litmus tests for a model's ability to properly utilize its longer context length, which is often the result of fine-tuning after pre-training. For some of these datasets, we explicitly create sets where the required information is located in different sections of the input context: the beginning, the middle, or the end.

1. HotPotQAXL
    - Description: [HotPotQA](https://hotpotqa.github.io/) is originally a dataset of ten documents and a question requiring comprehension of one or more of the supplied documents. The documents that are not needed to answer the question are called "distractor" documents. To extend this to longer context lengths, we randomly sample documents from the full set of documents across the dataset, adding them to the current datapoint until the set of documents and its question fills the target context length. We insert the "gold" document(s) (the document(s) containing the information that answers the question) within the first third, second third, or last third of the context length.
    - Lengths: 2k, 4k, 8k, 16k, 32k, 64k
    - Locations: beginning, middle, end
2. Key Value Pairs (Needle In a Haystack)
    - Description: We construct a `.json` of key-value pairs, where both the key and value are random hashes, in the style of [Lost in the Middle](https://github.com/nelson-liu/lost-in-the-middle). We ask the model to produce the value for a specific key found in the json; the queried pair is correspondingly located in the first third, second third, or last third of the json (a sketch of this construction appears after this list).
    - Lengths: 2k, 4k, 8k, 16k, 32k, 64k
    - Locations: beginning, middle, end
3. WikiQA Numeric
    - Description: [WikiQA Numeric](https://huggingface.co/datasets/abacusai/WikiQA-Altered_Numeric_QA) is a Wikipedia question answering dataset with a focus on questions with numeric answers. We preprocess the data only as needed to parse it easily in our framework.
    - Lengths: 2k, 4k, 8k, 16k
    - Locations: N/A
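Below is a minimal sketch of the key-value needle-in-a-haystack construction referenced in item 2. It assumes random hex tokens for keys and values and sizes the payload by pair count rather than by tokens; it is not the gauntlet's actual generation script, and `make_kv_haystack` is a made-up name.

```python
# Sketch only: random-hash key/value pairs with one queried ("needle") pair
# placed in a chosen third of the JSON. Sized by pair count for simplicity,
# not by tokens; not the gauntlet's generation script.
import json
import random
import secrets

def make_kv_haystack(num_pairs: int, location: str = 'middle'):
    pairs = [(secrets.token_hex(16), secrets.token_hex(16))
             for _ in range(num_pairs)]
    needle_key, needle_value = secrets.token_hex(16), secrets.token_hex(16)
    # Drop the needle into the first, second, or last third of the pairs.
    third = max(1, num_pairs // 3)
    offset = {'beginning': 0, 'middle': third, 'end': 2 * third}[location]
    pairs.insert(offset + random.randrange(third), (needle_key, needle_value))
    return json.dumps(dict(pairs)), needle_key, needle_value

payload, key, _ = make_kv_haystack(num_pairs=300, location='end')
prompt = f'{payload}\n\nWhat is the value associated with key "{key}"?'
```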