Add text2sql tasks #1414

Merged · 148 commits · Jan 23, 2025

Changes from 1 commit (of 148)
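This PR adds text2sql cards, templates, serializers, and an execution-accuracy metric to unitxt. As a minimal sketch of how the new task might be run end to end with the usual unitxt API — the catalog identifiers and score key below are hypothetical, inferred from the commit messages, so check the catalog entries this PR adds for the exact names:

```python
from unitxt import evaluate, load_dataset

dataset = load_dataset(
    card="cards.text2sql.bird",            # hypothetical card name (the PR adds a BIRD card)
    template="templates.text2sql.default",  # hypothetical template name
    loader_limit=20,                        # the PR's example also limits the loader to keep runs fast
    split="validation",
)

# Stand-in predictions; in practice these come from a model.
predictions = ["SELECT 1" for _ in dataset]

results = evaluate(predictions=predictions, data=dataset)
# Global scores should include the execution-accuracy score added in this PR.
print(results[0]["score"]["global"])
```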
8b26a9e
add text2sql templates
perlitz Dec 4, 2024
4752c56
add data management utility for text2sql
perlitz Dec 4, 2024
0713ff3
add basic template
perlitz Dec 4, 2024
7909077
add sql execution accuracy metric
perlitz Dec 4, 2024
4fdab71
Merge branch 'main' into add-text2sql
perlitz Dec 13, 2024
61a9232
add text2sql execution accuracy metric
perlitz Dec 4, 2024
94f10c4
add text2sql task
perlitz Dec 6, 2024
9a90f90
condition download on presence of a cache dir
perlitz Dec 6, 2024
b37b467
add init file
perlitz Dec 6, 2024
8d4894d
add processors
perlitz Dec 6, 2024
bc0d165
add processors
perlitz Dec 6, 2024
a185342
add basic template
perlitz Dec 6, 2024
f93eee9
change id to int
perlitz Dec 13, 2024
97c1bef
change notations in templates
perlitz Dec 13, 2024
3927e20
push to catalog
perlitz Dec 13, 2024
6a50032
add evidence, remove SL
perlitz Dec 13, 2024
cec65fd
remove unused function, fix
perlitz Dec 13, 2024
94c9c1e
fix imports from unitxt.text2sql
perlitz Dec 13, 2024
e5eb4a3
push to catalog
perlitz Dec 13, 2024
77eab83
fix cache location
perlitz Dec 13, 2024
2239ed6
add example
perlitz Dec 13, 2024
982d54d
fix imports
perlitz Dec 16, 2024
9a321e1
Merge branch 'main' into add-text2sql
perlitz Dec 16, 2024
2e337ad
add func_timeout to test reqs
perlitz Dec 16, 2024
c132d7d
fix typing
perlitz Dec 16, 2024
0cec726
change template name
perlitz Dec 16, 2024
dfa1af8
push to catalog
perlitz Dec 16, 2024
c857513
add req
perlitz Dec 18, 2024
9c566cc
add local model option
perlitz Dec 18, 2024
57f41f1
Merge branch 'main' into add-text2sql
perlitz Dec 18, 2024
67c9b4e
fix databases download
perlitz Dec 18, 2024
4a013aa
fix databases download
perlitz Dec 18, 2024
2c5fe5d
add loader limit to make example faster
perlitz Dec 18, 2024
02f1b23
fix cache paths, avoid re-download
perlitz Dec 18, 2024
1854c25
add type schema
perlitz Dec 18, 2024
c83c319
remove imports from inits
perlitz Dec 18, 2024
d51a6d7
add text2sql to inits
perlitz Dec 18, 2024
82e1fe8
update card to use serializers
perlitz Dec 18, 2024
98bc231
add schema serializer
perlitz Dec 18, 2024
2bce256
add text2sql serializer to default template
perlitz Dec 18, 2024
3b4c23a
add schema to task
perlitz Dec 18, 2024
5d9112f
adjust templates to use serializer
perlitz Dec 18, 2024
3a9bccc
adjust templates to use serializer
perlitz Dec 18, 2024
9fda158
fix processor
perlitz Dec 18, 2024
ac3ebee
remove target prefix from template
perlitz Dec 19, 2024
f313a8b
add shuffle to bird
perlitz Dec 19, 2024
e333d27
add shuffle to bird
perlitz Dec 19, 2024
3e23e4c
edit template
perlitz Dec 19, 2024
0d18070
remove comment from init
perlitz Dec 19, 2024
9fce798
clear processors code
perlitz Dec 19, 2024
ce38e3a
add option with ticks
perlitz Dec 19, 2024
38639c1
add anls metric
perlitz Dec 19, 2024
2d7aa81
Merge branch 'main' into add-text2sql
perlitz Jan 6, 2025
40f3a56
add template
perlitz Dec 20, 2024
980556c
drop comment
perlitz Jan 6, 2025
84e4695
remove recursion limit
perlitz Jan 6, 2025
4793e7c
add loader_limit to example
perlitz Jan 6, 2025
a68ead5
fix recursion error
perlitz Jan 6, 2025
29f2505
move import to within metric
perlitz Jan 6, 2025
fccbfd3
remove catalog files without prepare
perlitz Jan 6, 2025
543f716
fix typing
perlitz Jan 6, 2025
5512c9e
change template in example
perlitz Jan 6, 2025
aa4cac5
moving text2sql implementation to the main src dir
perlitz Jan 6, 2025
92aec0c
fix imports
perlitz Jan 6, 2025
a1a197a
fix imports
perlitz Jan 6, 2025
0aaac1d
fix imports
perlitz Jan 6, 2025
b0a4c7b
fix imports
perlitz Jan 6, 2025
fe9cd1e
import data_utils
perlitz Jan 6, 2025
342b7c5
Merge branch 'main' into add-text2sql
perlitz Jan 7, 2025
b6da498
Merge branch 'main' into add-text2sql
perlitz Jan 8, 2025
3a8de12
fix formatting
perlitz Jan 8, 2025
89b0ce0
refactor names
perlitz Jan 8, 2025
52982a6
add processors tests
perlitz Jan 8, 2025
cac3983
Merge branch 'main' into add-text2sql
perlitz Jan 8, 2025
0966eb7
add more tests
perlitz Jan 8, 2025
f5f4b50
add tests
perlitz Jan 8, 2025
2ed595e
refactor: allow more data sources
perlitz Jan 9, 2025
32c834b
allow db source input
perlitz Jan 9, 2025
b499715
organize imports
perlitz Jan 9, 2025
d75ecc6
update example
perlitz Jan 9, 2025
2fdab7b
add db_type to task
perlitz Jan 10, 2025
c02096b
format
perlitz Jan 10, 2025
5d689b5
add db_type to task
perlitz Jan 10, 2025
1317158
add local db definition ability
perlitz Jan 10, 2025
c3d5a2a
add EE tests
perlitz Jan 10, 2025
0fbbbb3
Merge branch 'main' into add-text2sql
perlitz Jan 10, 2025
b21c124
add tests
perlitz Jan 14, 2025
1f95e5a
rename file
perlitz Jan 14, 2025
4b8f029
rename file
perlitz Jan 14, 2025
52d8b84
update sql metric
perlitz Jan 14, 2025
e7222bf
rename file
perlitz Jan 14, 2025
9bd79cf
refactor types, serializers and metric
perlitz Jan 15, 2025
d651e9a
Merge branch 'main' into add-text2sql
perlitz Jan 15, 2025
9d41b4b
remove format_table
perlitz Jan 15, 2025
afe4121
add get schema for remote connector
perlitz Jan 15, 2025
a7f9f13
add tests for LocalConnector
perlitz Jan 15, 2025
fbb42fa
add tests for InMemoryDatabaseConnector
perlitz Jan 15, 2025
0ee81ea
add serializer tests
perlitz Jan 15, 2025
74525cb
remove fp test
perlitz Jan 15, 2025
14ddb22
fix serializer
perlitz Jan 15, 2025
ed2a422
make remote connector more robust
perlitz Jan 15, 2025
c6dae8d
Add schema serializer
perlitz Jan 15, 2025
69db33e
fix tests
perlitz Jan 15, 2025
09ba273
change error
perlitz Jan 15, 2025
59cf7ca
add data to bird card
perlitz Jan 16, 2025
24eb41d
fix tests
perlitz Jan 16, 2025
e0e82d6
add tests for db utils
perlitz Jan 16, 2025
540d3f8
fix metric test
perlitz Jan 16, 2025
cbd1fd7
pre-commit
perlitz Jan 16, 2025
baff237
Merge branch 'main' into add-text2sql
perlitz Jan 16, 2025
49f7548
Merge branch 'main' into add-text2sql
perlitz Jan 16, 2025
a434e9d
delete temp
perlitz Jan 16, 2025
e27e996
make id a str
perlitz Jan 17, 2025
7a2fcd9
fix access to db
perlitz Jan 17, 2025
e274dbd
add empty template
perlitz Jan 17, 2025
22d9511
compare the results entry from the metric
perlitz Jan 17, 2025
daa28bb
reformat
perlitz Jan 17, 2025
f061bb8
reformat
perlitz Jan 17, 2025
6a72f74
make hint optional
perlitz Jan 20, 2025
7c4db9d
remove serializer exception
perlitz Jan 20, 2025
acd254f
change loggers
perlitz Jan 21, 2025
b2e3247
add API loader
perlitz Jan 21, 2025
69a3227
Merge branch 'main' into add-text2sql
perlitz Jan 21, 2025
e045339
fix examples
perlitz Jan 21, 2025
7eff16c
fix tests
perlitz Jan 21, 2025
e9366fd
pre-commit run
perlitz Jan 21, 2025
13f7d17
optimize EE metric
perlitz Jan 21, 2025
95bfeb8
Merge branch 'main' into add-text2sql
perlitz Jan 21, 2025
57a0eb1
fix test
perlitz Jan 21, 2025
be557ff
align bird card to string id
perlitz Jan 21, 2025
9ae0061
move metric tests
perlitz Jan 21, 2025
4fc611f
Merge branch 'main' into add-text2sql
elronbandel Jan 22, 2025
b6e84d2
add metric dependencies to pyproject
perlitz Jan 22, 2025
236d332
handle error in EE metric
perlitz Jan 22, 2025
d0b2201
Merge branch 'main' into add-text2sql
perlitz Jan 22, 2025
07ec859
Refactor database query execution with caching and improved error han…
perlitz Jan 22, 2025
b0098d8
fix template bug
perlitz Jan 22, 2025
dcc5223
fix template bug
perlitz Jan 22, 2025
f4b4f8f
remove anls metric
perlitz Jan 22, 2025
bd57f50
make result matching invariant to order
perlitz Jan 22, 2025
487c3dc
make it faster and restrict scores
perlitz Jan 22, 2025
700093f
fix remote metric
perlitz Jan 22, 2025
a704aac
fix tests
perlitz Jan 22, 2025
66adaf2
remove retry tests
perlitz Jan 22, 2025
b39b23d
format tests
perlitz Jan 22, 2025
ba6aa2d
remove some tests
perlitz Jan 22, 2025
6f85c44
return anls to metric
perlitz Jan 22, 2025
e576aa0
Merge branch 'main' into add-text2sql
elronbandel Jan 23, 2025
pre-commit run
Signed-off-by: Yotam-Perlitz <[email protected]>
perlitz committed Jan 21, 2025

Verified: this commit was created on GitHub.com and signed with GitHub's verified signature.
commit e9366fdc42c33bb7a06577632f6b87968efc9fcb
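Every hunk in this commit applies the same mechanical change: the formatter rewraps multi-line assert statements, moving the parentheses between the asserted condition and the failure message. A minimal sketch of the two equivalent stylings that appear paired throughout (illustrative values, not lines from the diff):

```python
# Style A: the failure message is wrapped in parentheses.
assert value is not None, (
    "value must be set"
)

# Style B: the condition is wrapped; the message trails the closing paren.
assert (
    value is not None
), "value must be set"
```

Both forms are semantically identical; only the line wrapping changes, which is also why the .secrets.baseline line numbers at the end of the diff shift.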
6 changes: 3 additions & 3 deletions src/unitxt/db_utils.py
@@ -190,9 +190,9 @@ class RemoteDatabaseConnector(DatabaseConnector):
def __init__(self, db_config: SQLDatabase):
super().__init__(db_config)

assert db_config["db_id"], (
"db_id must be in db_config for RemoteDatabaseConnector"
)
assert db_config[
"db_id"
], "db_id must be in db_config for RemoteDatabaseConnector"
self.api_url, self.database_id = (
db_config["db_id"].split(",")[0],
db_config["db_id"].split("db_id=")[-1].split(",")[0],
19 changes: 9 additions & 10 deletions src/unitxt/loaders.py
@@ -572,15 +572,15 @@ def prepare(self):

def lazy_verify(self):
super().verify()
assert self.endpoint_url is not None, (
f"Please set the {self.endpoint_url_env} environmental variable"
)
assert self.aws_access_key_id is not None, (
f"Please set {self.aws_access_key_id_env} environmental variable"
)
assert self.aws_secret_access_key is not None, (
f"Please set {self.aws_secret_access_key_env} environmental variable"
)
assert (
self.endpoint_url is not None
), f"Please set the {self.endpoint_url_env} environmental variable"
assert (
self.aws_access_key_id is not None
), f"Please set {self.aws_access_key_id_env} environmental variable"
assert (
self.aws_secret_access_key is not None
), f"Please set {self.aws_secret_access_key_env} environmental variable"
if self.streaming:
raise NotImplementedError("LoadFromKaggle cannot load with streaming.")

@@ -1095,4 +1095,3 @@ def process(self) -> MultiStream:
self.__class__._loader_cache.max_size = settings.loader_cache_size
self.__class__._loader_cache[str(self)] = iterables
return MultiStream.from_iterables(iterables, copying=True)

202 changes: 99 additions & 103 deletions src/unitxt/metrics.py
@@ -1119,9 +1119,9 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generato
)

for reduction, fields in self.reduction_map.items():
assert reduction in self.implemented_reductions, (
f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"
)
assert (
reduction in self.implemented_reductions
), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"

if reduction == "mean":
for field_name in fields:
@@ -1390,12 +1390,12 @@ class InstanceMetric(StreamOperator, MetricWithConfidenceInterval):
def _validate_group_mean_task_data(self, instance):
# instances need to all have task_data field with field group_id
assert "task_data" in instance, "each instance must have an task_data field"
assert isinstance(instance["task_data"], dict), (
"each instance must have an task_data field that is a dict"
)
assert "group_id" in instance["task_data"], (
"each instance task_data dict must have a key group_id"
)
assert isinstance(
instance["task_data"], dict
), "each instance must have an task_data field that is a dict"
assert (
"group_id" in instance["task_data"]
), "each instance task_data dict must have a key group_id"

def _validate_group_mean_reduction(self):
"""Ensure that group_mean reduction_map is properly formatted.
@@ -1448,40 +1448,40 @@ def accuracy_diff(subgroup_scores_dict, expected_subgroup_types=['original', 'pa
2 'Why are ants eating my food?' 'original'
"""
# validate the reduction_map
assert "group_mean" in self.reduction_map, (
"reduction_map must have a 'group_mean' key"
)
assert (
"group_mean" in self.reduction_map
), "reduction_map must have a 'group_mean' key"
fields = self.reduction_map["group_mean"]
# for group_mean, expects a dict
assert isinstance(fields, dict)
assert "agg_func" in fields, (
"fields should have a key 'agg_func' whose value is a 3-element list of a function name, function definition, and a boolean indicator"
)
assert isinstance(fields["agg_func"], list), (
"fields['agg_func'] should be a list"
)
assert len(fields["agg_func"]) == 3, (
"fields['agg_func'] should be a 3-element list"
)
assert isinstance(fields["agg_func"][0], str), (
"first item in fields['agg_func'] should be a string name of a function"
)
assert callable(fields["agg_func"][1]), (
"second item in fields['agg_func'] should be a callable function"
)
assert isinstance(fields["agg_func"][2], bool), (
"third item in fields['agg_func'] should be a boolean value"
)
assert (
"agg_func" in fields
), "fields should have a key 'agg_func' whose value is a 3-element list of a function name, function definition, and a boolean indicator"
assert isinstance(
fields["agg_func"], list
), "fields['agg_func'] should be a list"
assert (
len(fields["agg_func"]) == 3
), "fields['agg_func'] should be a 3-element list"
assert isinstance(
fields["agg_func"][0], str
), "first item in fields['agg_func'] should be a string name of a function"
assert callable(
fields["agg_func"][1]
), "second item in fields['agg_func'] should be a callable function"
assert isinstance(
fields["agg_func"][2], bool
), "third item in fields['agg_func'] should be a boolean value"
if "score_fields" in fields:
assert isinstance(fields["score_fields"], list)

def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
instance_scores = self.compute_instance_scores(stream)
global_score = {"num_of_instances": len(instance_scores)}
for reduction_type, reduction_params in self.reduction_map.items():
assert reduction_type in self.implemented_reductions, (
f"Reduction {reduction_type} is not implemented, use one of {self.implemented_reductions}"
)
assert (
reduction_type in self.implemented_reductions
), f"Reduction {reduction_type} is not implemented, use one of {self.implemented_reductions}"

field_name_full_prefix = ""
# used for passing to the bootstrapping, depends on whether the groups are fixed or not
@@ -1579,9 +1579,7 @@ def compute_instance_scores(
assert (
"task_data" in instance
and self.subgroup_column in instance["task_data"]
), (
f"each instance task_data dict must have a key {self.subgroup_column}"
)
), f"each instance task_data dict must have a key {self.subgroup_column}"

task_data = instance["task_data"] if "task_data" in instance else {}

@@ -2183,15 +2181,15 @@ def disable_confidence_interval_calculation(self):

def verify(self):
super().verify()
assert self.metric is not None, (
f"'metric' is not set in {self.get_metric_name()}"
)
assert self.main_score is not None, (
f"'main_score' is not set in {self.get_metric_name()}"
)
assert isinstance(self.metric, Metric), (
f"'metric' is not set to a Metric class in {self.get_metric_name()} (type{self.metric})"
)
assert (
self.metric is not None
), f"'metric' is not set in {self.get_metric_name()}"
assert (
self.main_score is not None
), f"'main_score' is not set in {self.get_metric_name()}"
assert isinstance(
self.metric, Metric
), f"'metric' is not set to a Metric class in {self.get_metric_name()} (type{self.metric})"
if self.postpreprocess_steps is not None:
depr_message = "Field 'postpreprocess_steps' is deprecated. Please use 'postprocess_steps' for the same purpose."
warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
@@ -2212,9 +2210,9 @@ def prepare(self):
and isinstance(self.postprocess_steps, list)
and len(self.postprocess_steps) > 0
)
assert not (has_postpreprocess and has_postprocess), (
"Must define at most one of postpreprocess_steps (which is deprecated) and postprocess_steps (to be used from now on)"
)
assert not (
has_postpreprocess and has_postprocess
), "Must define at most one of postpreprocess_steps (which is deprecated) and postprocess_steps (to be used from now on)"
if has_postpreprocess:
self.postprocess_steps = self.postpreprocess_steps
self.prepare_score = SequentialOperator(
@@ -2289,16 +2287,14 @@ def verify(self):
Documentation.HUGGINGFACE_METRICS,
)

assert self.hf_additional_input_fields is None or isoftype(
self.hf_additional_input_fields, List[str]
), (
f"Argument hf_additional_input_fields should be either None or List[str]. It is now: {self.hf_additional_input_fields}."
)
assert self.hf_additional_input_fields_pass_one_value is None or isoftype(
self.hf_additional_input_fields_pass_one_value, List[str]
), (
f"Argument hf_additional_input_fields_pass_one_value should be either None or List[str]. It is now: {self.hf_additional_input_fields_pass_one_value}."
)
assert (
self.hf_additional_input_fields is None
or isoftype(self.hf_additional_input_fields, List[str])
), f"Argument hf_additional_input_fields should be either None or List[str]. It is now: {self.hf_additional_input_fields}."
assert (
self.hf_additional_input_fields_pass_one_value is None
or isoftype(self.hf_additional_input_fields_pass_one_value, List[str])
), f"Argument hf_additional_input_fields_pass_one_value should be either None or List[str]. It is now: {self.hf_additional_input_fields_pass_one_value}."

return super().verify()

@@ -2317,25 +2313,25 @@ def compute(
) -> dict:
passed_task_data = {}
for additional_input_field in self.hf_additional_input_fields:
assert additional_input_field in task_data[0], (
f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
)
assert (
additional_input_field in task_data[0]
), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
passed_task_data[additional_input_field] = [
additional_input[additional_input_field]
for additional_input in task_data
]
for additional_input_field in self.hf_additional_input_fields_pass_one_value:
assert additional_input_field in task_data[0], (
f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
)
assert (
additional_input_field in task_data[0]
), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"

values = {
additional_input[additional_input_field]
for additional_input in task_data
}
assert len(values) == 1, (
f"Values of '{additional_input_field}' field required by {__class__.__name__} should all be the same, but have multiple values {values}"
)
assert (
len(values) == 1
), f"Values of '{additional_input_field}' field required by {__class__.__name__} should all be the same, but have multiple values {values}"

passed_task_data[additional_input_field] = next(iter(values))

@@ -2350,22 +2346,22 @@ def compute(
result[self.main_score] = float(result[self.hf_main_score])
del result[self.hf_main_score]
if self.scale != 1.0:
assert self.scaled_fields is not None, (
f"Scaling factor was set to {self.scale}, but no fields specified"
)
assert (
self.scaled_fields is not None
), f"Scaling factor was set to {self.scale}, but no fields specified"
for key in self.scaled_fields:
assert key in result, (
f"Trying to scale field '{key}' which is not in results of metrics: {result}"
)
assert (
key in result
), f"Trying to scale field '{key}' which is not in results of metrics: {result}"
if isinstance(result[key], list):
assert all(isinstance(v, float) for v in result[key]), (
"Not all scaled field '{key}' values are floats: {result[key]}"
)
assert all(
isinstance(v, float) for v in result[key]
), "Not all scaled field '{key}' values are floats: {result[key]}"
result[key] = [v / self.scale for v in result[key]]
else:
assert isinstance(result[key], float), (
"Scaled field '{key}' is not float: {result[key]}"
)
assert isinstance(
result[key], float
), "Scaled field '{key}' is not float: {result[key]}"
result[key] /= self.scale
if self.main_score in result:
result[self.main_score] = float(result[self.main_score])
@@ -2394,9 +2390,9 @@ def compute(
) -> List[Dict[str, Any]]:
passed_task_data = {}
for additional_input_field in self.hf_additional_input_fields:
assert additional_input_field in task_data[0], (
f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
)
assert (
additional_input_field in task_data[0]
), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
passed_task_data[additional_input_field] = [
additional_input[additional_input_field]
for additional_input in task_data
@@ -2733,9 +2729,9 @@ def download_finqa_eval_script_file(url, local_path, hash_of_script):
response = requests.get(url)
response.raise_for_status()
content = response.content
assert hashlib.md5(content).hexdigest() == hash_of_script, (
f'URL ("{url}") is different than expected. Make sure you added the right one.'
)
assert (
hashlib.md5(content).hexdigest() == hash_of_script
), f'URL ("{url}") is different than expected. Make sure you added the right one.'

with open(local_path, "wb") as file:
file.write(content)
@@ -2869,9 +2865,9 @@ def compute(
labels=labels_param,
)
if isinstance(result[self.metric], numpy.ndarray):
assert len(result[self.metric]) == len(labels), (
f"F1 result ({result[self.metric]}) has more entries than labels ({labels})"
)
assert (
len(result[self.metric]) == len(labels)
), f"F1 result ({result[self.metric]}) has more entries than labels ({labels})"
final_result = {self.main_score: nan_mean(result[self.metric])}
for i, label in enumerate(labels):
final_result[self.metric + "_" + label] = result[self.metric][i]
@@ -4654,12 +4650,12 @@ def validate_subgroup_types(
for subgroup_name, score_list in subgroup_scores_dict.items()
}
)
assert isinstance(control_subgroup_types, list), (
"control_subgroup_types must be a list"
)
assert isinstance(comparison_subgroup_types, list), (
"comparison_subgroup_types must be a list"
)
assert isinstance(
control_subgroup_types, list
), "control_subgroup_types must be a list"
assert isinstance(
comparison_subgroup_types, list
), "comparison_subgroup_types must be a list"
# make sure each list is unique, so that labels aren't double-counted
control_subgroup_types = list(set(control_subgroup_types))
comparison_subgroup_types = list(set(comparison_subgroup_types))
@@ -4814,9 +4810,9 @@ def normalized_cohens_h(

# requires scores to be in [0,1]
for subgroup_name, score_list in subgroup_scores_dict.items():
assert all(0 <= score <= 1 for score in score_list), (
f"all {subgroup_name} scores must be in [0,1]"
)
assert all(
0 <= score <= 1 for score in score_list
), f"all {subgroup_name} scores must be in [0,1]"

# combine all scores from each label (if there are more than 1 in each group) into a list
group_scores_list = [
@@ -5620,9 +5616,9 @@ def prepare(self):

def create_ensemble_scores(self, instance):
score = self.ensemble(instance)
instance["prediction"] = (
score # We use here the prediction field to pass the score to the compute method.
)
instance[
"prediction"
] = score # We use here the prediction field to pass the score to the compute method.
return instance

def ensemble(self, instance):
@@ -5802,9 +5798,9 @@ def load_weights(json_file):
return json.load(file)

def ensemble(self, instance):
assert self.weights is not None, (
"RandomForestMetricsEnsemble must set self.weights before it can be used"
)
assert (
self.weights is not None
), "RandomForestMetricsEnsemble must set self.weights before it can be used"
ensemble_model = self.decode_forest(self.weights)

prediction_lst = []
6 changes: 3 additions & 3 deletions utils/.secrets.baseline
@@ -151,7 +151,7 @@
"filename": "src/unitxt/loaders.py",
"hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
"is_verified": false,
"line_number": 500,
"line_number": 502,
"is_secret": false
}
],
@@ -161,7 +161,7 @@
"filename": "src/unitxt/metrics.py",
"hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889",
"is_verified": false,
"line_number": 2754,
"line_number": 2748,
"is_secret": false
}
],
@@ -184,5 +184,5 @@
}
]
},
"generated_at": "2025-01-17T15:22:31Z"
"generated_at": "2025-01-21T18:10:24Z"
}
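The two line_number updates above track the assert rewrapping in src/unitxt/loaders.py and src/unitxt/metrics.py; the flagged strings and their hashes are unchanged.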