# task_config.py
import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple, Union

import _jsonnet
import dataclass_factory

from genbench.api import TaskType
from genbench.utils.validation import is_valid_url


@dataclass
class DataSourceConfig:
    """
    Configuration class for specifying the data source.

    Parameters:
        type (`Literal["hf", "manual"]`):
            Submissions may use either a HuggingFace dataset ("hf") or publicly
            accessible, date-stamped URLs to dataset files ("manual"). Exactly one
            option must be chosen.
        hf_id (`Optional[Union[str, Tuple[str, str]]]`, optional):
            HuggingFace dataset id, e.g. 'snli', or ('super_glue', 'MultiRC') for
            datasets that are part of benchmarks. Only needed if `type` == "hf".
        git_commit_sha (`Optional[str]`, optional):
            Git commit SHA of the data source. To ensure the HF dataset is read from
            the same version every time, you need to specify the commit SHA of the
            HF dataset. You can find it at
            https://huggingface.co/datasets/<dataset-name>/commits/main
            Only needed if `type` == "hf".
        test (`Optional[str]`, optional):
            Test set URL. Only needed if `type` == "manual".
        train (`Optional[str]`, optional):
            Train set URL. Only needed if `type` == "manual".
        validation (`Optional[str]`, optional):
            Validation set URL. Only needed if `type` == "manual".

    Raises:
        AssertionError: If the data source type is "hf" and either `hf_id` or
            `git_commit_sha` is None.
        AssertionError: If the data source type is "manual" and any of the provided
            URLs is invalid.
    """

    type: Literal["hf", "manual"] = field(metadata={"help": "Type of the data source. e.g. 'hf'"})
    hf_id: Optional[Union[str, Tuple[str, str]]] = field(
        default=None, metadata={"help": "HuggingFace dataset id. e.g. 'glue'"}
    )
    git_commit_sha: Optional[str] = field(
        default=None,
        metadata={"help": "Git commit sha of the data source. e.g. '070042b....'"},
    )
    test: Optional[str] = field(
        default=None,
        metadata={"help": "Test set URL. e.g. 'https://example.com/test.jsonl'"},
    )
    train: Optional[str] = field(
        default=None,
        metadata={"help": "Train set URL. e.g. 'https://example.com/train.jsonl'"},
    )
    validation: Optional[str] = field(
        default=None,
        metadata={"help": "Validation set URL. e.g. 'https://example.com/val.jsonl'"},
    )

    def __post_init__(self):
        if self.type == "hf":
            assert self.hf_id is not None
            assert self.git_commit_sha is not None
            assert isinstance(self.hf_id, str) or (isinstance(self.hf_id, tuple) and len(self.hf_id) == 2)
        elif self.type == "manual":
            assert self.test is not None
            assert all((url is None) or is_valid_url(url) for url in [self.test, self.train, self.validation])
        else:
            raise ValueError(f"Invalid value for data source type: {self.type}. Must be one of ['hf', 'manual']")
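
# Usage sketch for both source types (illustrative only; the dataset id, commit
# SHA, and URLs below are hypothetical placeholders):
#
#   hf_source = DataSourceConfig(
#       type="hf",
#       hf_id=("super_glue", "MultiRC"),
#       git_commit_sha="070042b...",  # pins a specific dataset revision
#   )
#   manual_source = DataSourceConfig(
#       type="manual",
#       test="https://example.com/test.jsonl",
#       train="https://example.com/train.jsonl",
#   )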


@dataclass
class EvaluationMetricConfig:
    """
    Represents the configuration for specifying an evaluation metric.

    Attributes:
        hf_id (`Union[str, Tuple[str, str]]`):
            The HuggingFace metric identifier: either a single string (e.g. 'accuracy')
            or a tuple of two strings.
        git_commit_sha (`str`):
            The git commit SHA corresponding to the specific version of the metric
            (e.g. '070042b....').
        best_score (`float`):
            The highest value the metric can achieve, used for reference (e.g. 1.0).
        compute_extra_kwargs (`Optional[Dict[str, str]]`, optional):
            Additional keyword arguments to pass to the metric's compute method.
            Defaults to None.

    Raises:
        AssertionError: If `hf_id` or `git_commit_sha` is None.
        AssertionError: If `hf_id` is not a string or a tuple of two strings.
        AssertionError: If `best_score` is None.
    """

    hf_id: Union[str, Tuple[str, str]] = field(metadata={"help": "HuggingFace metric id. e.g. 'accuracy'"})
    git_commit_sha: str = field(metadata={"help": "Git commit sha of the metric. e.g. '070042b....'"})
    best_score: float = field(metadata={"help": "Best value of the metric. e.g. 1.0"})
    compute_extra_kwargs: Optional[Dict[str, str]] = field(
        default=None,
        metadata={"help": "Extra kwargs to pass to the metric's compute method."},
    )

    def __post_init__(self):
        assert self.hf_id is not None
        assert self.git_commit_sha is not None
        assert isinstance(self.hf_id, str) or (isinstance(self.hf_id, tuple) and len(self.hf_id) == 2)
        assert self.best_score is not None, "Best value must be specified."
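
# Usage sketch (illustrative only; the commit SHA is a hypothetical placeholder):
#
#   accuracy = EvaluationMetricConfig(
#       hf_id="accuracy",
#       git_commit_sha="070042b...",
#       best_score=1.0,
#   )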


@dataclass
class PromptBuilderConfig:
    """
    Configuration class for building prompts for generative tasks.

    This configuration follows the options for prompt construction as defined in BIG-bench:
    https://github.com/google/BIG-bench/blob/main/docs/doc.md#optional-fields

    Attributes:
        instruction_zero_shot (`str`, optional):
            Instruction prepended to the model's input in the zero-shot setting.
            Defaults to an empty string.
        instruction_few_shot (`str`, optional):
            Instruction prepended to the model's input in the few-shot setting.
            Defaults to an empty string.
        input_prefix (`str`, optional):
            Prefix added before the input in the model's prompt. Defaults to "Q: ".
        output_prefix (`str`, optional):
            Prefix added before the output in the model's prompt. Defaults to "\nA: ".
        append_choices_to_input (`bool`, optional):
            Whether to append the choices to the model's input. Defaults to True.
        choices_prefix (`str`, optional):
            Prefix added before the choices in the model's prompt. Defaults to "\nchoices: \n".
        choice_item_postfix (`str`, optional):
            Separator added between choice items. Defaults to "\n".
        choice_item_prefix (`str`, optional):
            Prefix added before each choice item. Defaults to "- ".
        sequence_labeling_separator (`str`, optional):
            Separator between tokens in sequence labeling tasks. Defaults to ",".
        permute_choices (`bool`, optional):
            Whether to permute the order of choices. Defaults to False.
        few_shot_example_separator (`str`, optional):
            Separator between few-shot examples. Defaults to "\n\n".
        stop_string (`str`, optional):
            String indicating the end of the generation. Defaults to "\n\n".
    """

    instruction_zero_shot: str = field(
        default="",
        metadata={"help": "Instruction of the task. Will be prepended to the model's input. e.g. 'Add two numbers:'"},
    )
    instruction_few_shot: str = field(
        default="",
        metadata={
            "help": (
                "Instruction of the task. Will be prepended to the"
                " model's input. e.g. 'Add two numbers. Here are some examples:'"
            )
        },
    )
    input_prefix: str = field(default="Q: ", metadata={"help": "Prefix of the model's input."})
    output_prefix: str = field(default="\nA: ", metadata={"help": "Prefix of the model's output."})
    append_choices_to_input: bool = field(
        default=True,
        metadata={"help": "Whether to append the choices to the model's input."},
    )
    choices_prefix: str = field(
        default="\nchoices: \n",
        metadata={"help": "Prefix of the model's choice."},
    )
    choice_item_postfix: str = field(
        default="\n",
        metadata={"help": "Separator between the choices."},
    )
    choice_item_prefix: str = field(
        default="- ",
        metadata={"help": "Prefix of the model's choice item."},
    )
    sequence_labeling_separator: str = field(
        default=",",
        metadata={"help": "Separator between the sequence labeling tokens."},
    )
    permute_choices: bool = field(
        default=False,
        metadata={"help": "Whether to permute the choices."},
    )
    few_shot_example_separator: str = field(
        default="\n\n",
        metadata={"help": "Separator between the few-shot examples."},
    )
    stop_string: str = field(
        default="\n\n",
        metadata={"help": "Stop string to indicate the end of the generation."},
    )
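
# Sketch of how these fields might compose into a single multi-choice prompt
# (the actual rendering is presumably done by the genbench task runner; this
# only approximates the BIG-bench-style layout described above):
#
#   cfg = PromptBuilderConfig()
#   prompt = (
#       cfg.input_prefix + "What is 2 + 2?"
#       + cfg.choices_prefix
#       + "".join(cfg.choice_item_prefix + c + cfg.choice_item_postfix for c in ["3", "4"])
#       + cfg.output_prefix
#   )
#   # prompt == "Q: What is 2 + 2?\nchoices: \n- 3\n- 4\n\nA: "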


@dataclass
class PromptBaseTestingConfig:
    prompt_builder: PromptBuilderConfig = field(metadata={"help": "Prompt builder configuration."})


@dataclass
class PreparationStrategiesConfig:
    prompt_based_testing: Optional[PromptBaseTestingConfig] = field(
        default=None, metadata={"help": "Prompt-based testing configuration."}
    )
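
# Usage sketch (illustrative): enabling prompt-based testing with the default
# prompt layout.
#
#   strategies = PreparationStrategiesConfig(
#       prompt_based_testing=PromptBaseTestingConfig(prompt_builder=PromptBuilderConfig())
#   )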


@dataclass
class TaskConfig:
    """
    Configuration class for defining a task.

    Parameters:
        name (str):
            Name of the task. e.g. 'Addition'
        description (str):
            Description of the task. e.g. 'Addition of two numbers'
        keywords (List[str]):
            Keywords of the task.
        authors (List[str]):
            Authors of the task.
        data_source (DataSourceConfig):
            Data source configuration.
        task_type (TaskType):
            Type of the task: 'free_form', 'multi_choice', or 'sequence_labeling'.
        preparation_strategies (PreparationStrategiesConfig):
            Preparation strategies configuration.
        field_mapping (Optional[Dict[str, str]], optional):
            Mapping from the fields in the data source to the fields that the task
            ('input', 'target') expects. Defaults to None.
        evaluation_metrics (Optional[List[EvaluationMetricConfig]], optional):
            Evaluation metric configuration. Defaults to None.
        split_file (Optional[str], optional):
            Path to the split file. Defaults to None.
        free_form_output_regex (Optional[str], optional):
            Regex to extract the output from the free-form answer. Defaults to None.
        has_validation_set (bool, optional):
            Whether the task provides a validation set. Defaults to False.
        has_train_set (bool, optional):
            Whether the task provides a train set. Defaults to False.
    """

    name: str = field(metadata={"help": "Name of the task. e.g. 'Addition'"})
    description: str = field(metadata={"help": "Description of the task. e.g. 'Addition of two numbers'"})
    keywords: List[str] = field(
        metadata={"help": "Keywords of the task"},
    )
    authors: List[str] = field(
        metadata={"help": "Authors of the task"},
    )
    data_source: DataSourceConfig = field(
        metadata={"help": "Data source configuration"},
    )
    task_type: TaskType = field(
        metadata={"help": "Type of the task. e.g. 'free_form'"},
    )
    preparation_strategies: PreparationStrategiesConfig = field(
        metadata={"help": "Preparation strategies configuration."}
    )
    field_mapping: Optional[Dict[str, str]] = field(
        default=None,
        metadata={
            "help": (
                "Mapping from the fields in the data source to the fields that the task ('input','target') expects."
            )
        },
    )
    evaluation_metrics: Optional[List[EvaluationMetricConfig]] = field(
        default=None,
        metadata={"help": "Evaluation metric configuration"},
    )
    split_file: Optional[str] = field(
        default=None,
        metadata={"help": "Split filename"},
    )
    free_form_output_regex: Optional[str] = field(
        default=None,
        metadata={"help": "Regex to extract the output from the free form answer"},
    )
    has_validation_set: bool = field(
        default=False,
        metadata={"help": "Whether the task provides a validation set"},
    )
    has_train_set: bool = field(
        default=False,
        metadata={"help": "Whether the task provides a train set"},
    )

    def __post_init__(self):
        if self.task_type == "free_form" and self.free_form_output_regex is None:
            raise ValueError("Task type is free_form but no free_form_output_regex is provided.")

        if self.field_mapping is not None:
            assert "input" in self.field_mapping, "Field mapping must contain 'input' field."
            assert "target" in self.field_mapping, "Field mapping must contain 'target' field."

        assert self.keywords is not None and len(self.keywords) > 0, "Keywords must be provided for the task."
        assert self.authors is not None and len(self.authors) > 0, "Authors must be provided for the task."
        assert self.preparation_strategies is not None, "Preparation strategies must be provided for the task."

    @staticmethod
    def from_jsonnet(jsonnet_str: Optional[str] = None, jsonnet_path: Optional[Path] = None) -> "TaskConfig":
        if jsonnet_str is None and jsonnet_path is None:
            raise ValueError("Either jsonnet_str or jsonnet_path must be provided.")
        elif jsonnet_str is not None and jsonnet_path is not None:
            raise ValueError("Only one of jsonnet_str or jsonnet_path must be provided.")

        if jsonnet_str is None:
            jsonnet_str = jsonnet_path.read_text()

        json_str = _jsonnet.evaluate_snippet("snippet", jsonnet_str)
        json_dict = json.loads(json_str)

        factory = dataclass_factory.Factory()
        config: TaskConfig = factory.load(json_dict, TaskConfig)
        return config

    def to_json(self, path: Path) -> None:
        path.write_text(json.dumps(asdict(self), indent=4))
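

# Minimal end-to-end sketch (illustrative only: the task metadata, dataset id,
# commit SHAs, and permissive regex below are hypothetical placeholders, and the
# accepted task_type values are those listed in the TaskConfig docstring).
if __name__ == "__main__":
    example_jsonnet = """
    {
        name: 'Addition',
        description: 'Addition of two numbers',
        keywords: ['arithmetic'],
        authors: ['Jane Doe'],
        data_source: {
            type: 'hf',
            hf_id: 'snli',
            git_commit_sha: '070042b...',
        },
        task_type: 'free_form',
        free_form_output_regex: '.*',
        evaluation_metrics: [
            {
                hf_id: 'accuracy',
                git_commit_sha: '070042b...',
                best_score: 1.0,
            },
        ],
        preparation_strategies: {},
    }
    """
    config = TaskConfig.from_jsonnet(jsonnet_str=example_jsonnet)
    print(config.name)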