From d0f977efb5dc2259e9bb01d2584f130aba5ffae3 Mon Sep 17 00:00:00 2001 From: Adel Basli Date: Mon, 18 Nov 2024 16:41:40 +0100 Subject: [PATCH] ran some PEP formatting on old python files (#359) --- GSoC2018/table_detection/resize.py | 18 +- .../table_detection/utils/label_map_util.py | 263 +-- .../utils/visualization_utils.py | 1443 +++++++++-------- ai-emlyon/eml/datanavig.py | 786 ++++----- ai-emlyon/eml/eda.py | 15 +- ai-emlyon/eml/model_eval.py | 107 +- ai-emlyon/eml/taxonomy.py | 252 +-- ai-emlyon/food_model/evaluator.py | 189 ++- ai-emlyon/food_model/xgfood.py | 208 ++- ai-emlyon/scripts/lucain_script_taxonomy.py | 6 +- ai-emlyon/scripts/merging_script.py | 222 ++- data4good_logo_detection/setup.py | 2 +- .../src/data/explore_data.py | 4 +- .../dataset-generation/clean_dataset.py | 4 +- .../model-analysis/evaluate_model.py | 15 +- .../model-analysis/test_prediction.py | 7 +- ingredient_extraction/train/run_pred_colab.py | 36 +- .../ANN_benchmark/elasticsearch_benchmark.py | 595 +++---- .../ANN_benchmark/faiss_benchmark.py | 176 +- .../ANN_benchmark/recall_computation.py | 105 +- .../ANN_benchmark/redis_benchmark.py | 326 ++-- logo-ann/benchmarks/ANN_benchmark/utils.py | 15 +- .../embedding_models_benchmark/main.py | 12 +- logo-ann/dataset/clean_logo_dataset.py | 1 - logo-ann/dataset/create_logo_dataset.py | 6 +- logo-ann/generation/02_generate_embeddings.py | 50 +- logo-ann/generation/03_generate_index.py | 22 +- logo-classifier/annotate_logos.py | 29 +- logo-classifier/dataset.py | 127 +- logo-classifier/settings.py | 2 +- logo-classifier/to_delete.py | 8 +- logo-classifier/train.py | 195 ++- logo-classifier/utils.py | 24 +- logo-classifier/visualize_logos.py | 86 +- .../cli/inference_yolo_tflite.py | 16 +- .../tensorflow_object_api/object_detection.py | 9 +- .../string_int_label_map_pb2.py | 4 +- ocr_cleaning/multiproc_ocr_cleaning.py | 1 + spellcheck/old/SessionState.py | 1 + spellcheck/old/evaluation/metrics.py | 13 +- spellcheck/old/ingredients.py | 1 + spellcheck/old/label.py | 5 +- spellcheck/old/mongo.py | 5 +- .../scripts/argilla/benchmark/add_records.py | 29 +- .../argilla/benchmark/deploy_benchmark.py | 3 +- .../argilla/benchmark/extract_benchmark.py | 73 +- .../argilla/benchmark/update_benchmark.py | 74 +- .../argilla/dataset/deploy_training.py | 5 +- spellcheck/scripts/batch/main.py | 103 +- .../scripts/benchmark/create_benchmark.py | 8 +- .../benchmark/create_test_benchmark.py | 20 +- ..._synthetic_data_for_additional_products.py | 18 +- .../scripts/dags/benchmark_generation.py | 36 +- spellcheck/scripts/dags/data_processing.py | 29 +- .../scripts/dags/extract_from_argilla.py | 38 +- .../dags/training/pretraining_finetuning.py | 114 +- spellcheck/scripts/dags/training/training.py | 88 +- spellcheck/scripts/dataset/0_extract_data.py | 34 +- .../dataset/1_generate_synthetic_data.py | 16 +- .../scripts/dataset/2_convert_to_dataset.py | 2 +- spellcheck/scripts/evaluation/evaluate.py | 78 +- .../scripts/old_to_new/0_convert_old_data.py | 9 +- .../scripts/training/flan-t5/flan-t5.py | 289 ++-- spellcheck/scripts/training/llm/llm.py | 60 +- .../scripts/training/llm/pretraining_llm.py | 59 +- spellcheck/src/spellcheck/__init__.py | 2 +- .../src/spellcheck/argilla/deployment.py | 243 ++- .../src/spellcheck/argilla/extraction.py | 71 +- spellcheck/src/spellcheck/config.py | 4 +- .../src/spellcheck/evaluation/evaluation.py | 61 +- .../src/spellcheck/evaluation/evaluator.py | 181 ++- spellcheck/src/spellcheck/model.py | 59 +- spellcheck/src/spellcheck/processing.py | 28 +- 
spellcheck/src/spellcheck/prompt.py | 5 +- spellcheck/src/spellcheck/spellcheck.py | 6 +- spellcheck/src/spellcheck/training/configs.py | 5 +- spellcheck/src/spellcheck/training/trainer.py | 216 ++- spellcheck/src/spellcheck/training/utils.py | 28 +- spellcheck/src/spellcheck/utils.py | 44 +- spellcheck/tests/test_argilla.py | 63 +- spellcheck/tests/test_evaluate.py | 353 ++-- spellcheck/tests/test_processing.py | 61 +- spellcheck/tests/test_utils.py | 17 +- 83 files changed, 4666 insertions(+), 3377 deletions(-) diff --git a/GSoC2018/table_detection/resize.py b/GSoC2018/table_detection/resize.py index 805c6a67..611d8a6e 100644 --- a/GSoC2018/table_detection/resize.py +++ b/GSoC2018/table_detection/resize.py @@ -2,35 +2,37 @@ import glob from PIL import Image + def get_filename(file): name_list = filename.split("/") name = name_list[-1].split(".") return name[0] + def resize(filename, nx, ny): imagename = get_filename(filename) img = im.resize((int(nx), int(ny)), Image.ANTIALIAS) img.save("test_images/{}.jpg".format(imagename), optimize=True, quality=95) + count = 0 -path = 'test_images/' -for filename in glob.glob(os.path.join(path, '*.jpg')): +path = "test_images/" +for filename in glob.glob(os.path.join(path, "*.jpg")): im = Image.open(filename) nx, ny = im.size - if(nx >= ny): + if nx >= ny: new_nx = 1000 ratio = new_nx / nx - new_ny = ratio * ny + new_ny = ratio * ny resize(filename, new_nx, new_ny) else: new_ny = 1000 ratio = new_ny / ny - new_nx = ratio * nx + new_nx = ratio * nx resize(filename, new_nx, new_ny) print(filename) - count+=1 - + count += 1 -print (count) \ No newline at end of file +print(count) diff --git a/GSoC2018/table_detection/utils/label_map_util.py b/GSoC2018/table_detection/utils/label_map_util.py index aef46c1d..121ac9d5 100644 --- a/GSoC2018/table_detection/utils/label_map_util.py +++ b/GSoC2018/table_detection/utils/label_map_util.py @@ -23,159 +23,164 @@ def _validate_label_map(label_map): - """Checks if a label map is valid. + """Checks if a label map is valid. - Args: - label_map: StringIntLabelMap to validate. + Args: + label_map: StringIntLabelMap to validate. - Raises: - ValueError: if label map is invalid. - """ - for item in label_map.item: - if item.id < 0: - raise ValueError('Label map ids should be >= 0.') - if (item.id == 0 and item.name != 'background' and - item.display_name != 'background'): - raise ValueError('Label map id 0 is reserved for the background label') + Raises: + ValueError: if label map is invalid. + """ + for item in label_map.item: + if item.id < 0: + raise ValueError("Label map ids should be >= 0.") + if ( + item.id == 0 + and item.name != "background" + and item.display_name != "background" + ): + raise ValueError("Label map id 0 is reserved for the background label") def create_category_index(categories): - """Creates dictionary of COCO compatible categories keyed by category id. + """Creates dictionary of COCO compatible categories keyed by category id. - Args: - categories: a list of dicts, each of which has the following keys: - 'id': (required) an integer id uniquely identifying this category. - 'name': (required) string representing category name - e.g., 'cat', 'dog', 'pizza'. + Args: + categories: a list of dicts, each of which has the following keys: + 'id': (required) an integer id uniquely identifying this category. + 'name': (required) string representing category name + e.g., 'cat', 'dog', 'pizza'. 
- Returns: - category_index: a dict containing the same entries as categories, but keyed - by the 'id' field of each category. - """ - category_index = {} - for cat in categories: - category_index[cat['id']] = cat - return category_index + Returns: + category_index: a dict containing the same entries as categories, but keyed + by the 'id' field of each category. + """ + category_index = {} + for cat in categories: + category_index[cat["id"]] = cat + return category_index def get_max_label_map_index(label_map): - """Get maximum index in label map. - - Args: - label_map: a StringIntLabelMapProto - - Returns: - an integer - """ - return max([item.id for item in label_map.item]) - - -def convert_label_map_to_categories(label_map, - max_num_classes, - use_display_name=True): - """Loads label map proto and returns categories list compatible with eval. - - This function loads a label map and returns a list of dicts, each of which - has the following keys: - 'id': (required) an integer id uniquely identifying this category. - 'name': (required) string representing category name - e.g., 'cat', 'dog', 'pizza'. - We only allow class into the list if its id-label_id_offset is - between 0 (inclusive) and max_num_classes (exclusive). - If there are several items mapping to the same id in the label map, - we will only keep the first one in the categories list. - - Args: - label_map: a StringIntLabelMapProto or None. If None, a default categories - list is created with max_num_classes categories. - max_num_classes: maximum number of (consecutive) label indices to include. - use_display_name: (boolean) choose whether to load 'display_name' field - as category name. If False or if the display_name field does not exist, - uses 'name' field as category names instead. - Returns: - categories: a list of dictionaries representing all possible categories. - """ - categories = [] - list_of_ids_already_added = [] - if not label_map: - label_id_offset = 1 - for class_id in range(max_num_classes): - categories.append({ - 'id': class_id + label_id_offset, - 'name': 'category_{}'.format(class_id + label_id_offset) - }) + """Get maximum index in label map. + + Args: + label_map: a StringIntLabelMapProto + + Returns: + an integer + """ + return max([item.id for item in label_map.item]) + + +def convert_label_map_to_categories(label_map, max_num_classes, use_display_name=True): + """Loads label map proto and returns categories list compatible with eval. + + This function loads a label map and returns a list of dicts, each of which + has the following keys: + 'id': (required) an integer id uniquely identifying this category. + 'name': (required) string representing category name + e.g., 'cat', 'dog', 'pizza'. + We only allow class into the list if its id-label_id_offset is + between 0 (inclusive) and max_num_classes (exclusive). + If there are several items mapping to the same id in the label map, + we will only keep the first one in the categories list. + + Args: + label_map: a StringIntLabelMapProto or None. If None, a default categories + list is created with max_num_classes categories. + max_num_classes: maximum number of (consecutive) label indices to include. + use_display_name: (boolean) choose whether to load 'display_name' field + as category name. If False or if the display_name field does not exist, + uses 'name' field as category names instead. + Returns: + categories: a list of dictionaries representing all possible categories. 
+ """ + categories = [] + list_of_ids_already_added = [] + if not label_map: + label_id_offset = 1 + for class_id in range(max_num_classes): + categories.append( + { + "id": class_id + label_id_offset, + "name": "category_{}".format(class_id + label_id_offset), + } + ) + return categories + for item in label_map.item: + if not 0 < item.id <= max_num_classes: + logging.info( + "Ignore item %d since it falls outside of requested " "label range.", + item.id, + ) + continue + if use_display_name and item.HasField("display_name"): + name = item.display_name + else: + name = item.name + if item.id not in list_of_ids_already_added: + list_of_ids_already_added.append(item.id) + categories.append({"id": item.id, "name": name}) return categories - for item in label_map.item: - if not 0 < item.id <= max_num_classes: - logging.info('Ignore item %d since it falls outside of requested ' - 'label range.', item.id) - continue - if use_display_name and item.HasField('display_name'): - name = item.display_name - else: - name = item.name - if item.id not in list_of_ids_already_added: - list_of_ids_already_added.append(item.id) - categories.append({'id': item.id, 'name': name}) - return categories def load_labelmap(path): - """Loads label map proto. - - Args: - path: path to StringIntLabelMap proto text file. - Returns: - a StringIntLabelMapProto - """ - with tf.gfile.GFile(path, 'r') as fid: - label_map_string = fid.read() - label_map = string_int_label_map_pb2.StringIntLabelMap() - try: - text_format.Merge(label_map_string, label_map) - except text_format.ParseError: - label_map.ParseFromString(label_map_string) - _validate_label_map(label_map) - return label_map + """Loads label map proto. + + Args: + path: path to StringIntLabelMap proto text file. + Returns: + a StringIntLabelMapProto + """ + with tf.gfile.GFile(path, "r") as fid: + label_map_string = fid.read() + label_map = string_int_label_map_pb2.StringIntLabelMap() + try: + text_format.Merge(label_map_string, label_map) + except text_format.ParseError: + label_map.ParseFromString(label_map_string) + _validate_label_map(label_map) + return label_map def get_label_map_dict(label_map_path, use_display_name=False): - """Reads a label map and returns a dictionary of label names to id. + """Reads a label map and returns a dictionary of label names to id. - Args: - label_map_path: path to label_map. - use_display_name: whether to use the label map items' display names as keys. + Args: + label_map_path: path to label_map. + use_display_name: whether to use the label map items' display names as keys. - Returns: - A dictionary mapping label names to id. - """ - label_map = load_labelmap(label_map_path) - label_map_dict = {} - for item in label_map.item: - if use_display_name: - label_map_dict[item.display_name] = item.id - else: - label_map_dict[item.name] = item.id - return label_map_dict + Returns: + A dictionary mapping label names to id. + """ + label_map = load_labelmap(label_map_path) + label_map_dict = {} + for item in label_map.item: + if use_display_name: + label_map_dict[item.display_name] = item.id + else: + label_map_dict[item.name] = item.id + return label_map_dict def create_category_index_from_labelmap(label_map_path): - """Reads a label map and returns a category index. + """Reads a label map and returns a category index. - Args: - label_map_path: Path to `StringIntLabelMap` proto text file. + Args: + label_map_path: Path to `StringIntLabelMap` proto text file. 
- Returns: - A category index, which is a dictionary that maps integer ids to dicts - containing categories, e.g. - {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...} - """ - label_map = load_labelmap(label_map_path) - max_num_classes = max(item.id for item in label_map.item) - categories = convert_label_map_to_categories(label_map, max_num_classes) - return create_category_index(categories) + Returns: + A category index, which is a dictionary that maps integer ids to dicts + containing categories, e.g. + {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...} + """ + label_map = load_labelmap(label_map_path) + max_num_classes = max(item.id for item in label_map.item) + categories = convert_label_map_to_categories(label_map, max_num_classes) + return create_category_index(categories) def create_class_agnostic_category_index(): - """Creates a category index with a single `object` class.""" - return {1: {'id': 1, 'name': 'object'}} + """Creates a category index with a single `object` class.""" + return {1: {"id": 1, "name": "object"}} diff --git a/GSoC2018/table_detection/utils/visualization_utils.py b/GSoC2018/table_detection/utils/visualization_utils.py index 47a78672..7763ce8d 100644 --- a/GSoC2018/table_detection/utils/visualization_utils.py +++ b/GSoC2018/table_detection/utils/visualization_utils.py @@ -21,8 +21,11 @@ """ import collections import functools + # Set headless-friendly backend. -import matplotlib; matplotlib.use('Agg') # pylint: disable=multiple-statements +import matplotlib + +matplotlib.use("Agg") # pylint: disable=multiple-statements import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top import numpy as np import PIL.Image as Image @@ -38,491 +41,647 @@ _TITLE_LEFT_MARGIN = 10 _TITLE_TOP_MARGIN = 10 STANDARD_COLORS = [ - 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque', - 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite', - 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan', - 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange', - 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet', - 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite', - 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod', - 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki', - 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue', - 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey', - 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue', - 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime', - 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid', - 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen', - 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin', - 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed', - 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed', - 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple', - 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown', - 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue', - 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow', - 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White', - 'WhiteSmoke', 'Yellow', 'YellowGreen' + "AliceBlue", + "Chartreuse", + "Aqua", + "Aquamarine", + 
"Azure", + "Beige", + "Bisque", + "BlanchedAlmond", + "BlueViolet", + "BurlyWood", + "CadetBlue", + "AntiqueWhite", + "Chocolate", + "Coral", + "CornflowerBlue", + "Cornsilk", + "Crimson", + "Cyan", + "DarkCyan", + "DarkGoldenRod", + "DarkGrey", + "DarkKhaki", + "DarkOrange", + "DarkOrchid", + "DarkSalmon", + "DarkSeaGreen", + "DarkTurquoise", + "DarkViolet", + "DeepPink", + "DeepSkyBlue", + "DodgerBlue", + "FireBrick", + "FloralWhite", + "ForestGreen", + "Fuchsia", + "Gainsboro", + "GhostWhite", + "Gold", + "GoldenRod", + "Salmon", + "Tan", + "HoneyDew", + "HotPink", + "IndianRed", + "Ivory", + "Khaki", + "Lavender", + "LavenderBlush", + "LawnGreen", + "LemonChiffon", + "LightBlue", + "LightCoral", + "LightCyan", + "LightGoldenRodYellow", + "LightGray", + "LightGrey", + "LightGreen", + "LightPink", + "LightSalmon", + "LightSeaGreen", + "LightSkyBlue", + "LightSlateGray", + "LightSlateGrey", + "LightSteelBlue", + "LightYellow", + "Lime", + "LimeGreen", + "Linen", + "Magenta", + "MediumAquaMarine", + "MediumOrchid", + "MediumPurple", + "MediumSeaGreen", + "MediumSlateBlue", + "MediumSpringGreen", + "MediumTurquoise", + "MediumVioletRed", + "MintCream", + "MistyRose", + "Moccasin", + "NavajoWhite", + "OldLace", + "Olive", + "OliveDrab", + "Orange", + "OrangeRed", + "Orchid", + "PaleGoldenRod", + "PaleGreen", + "PaleTurquoise", + "PaleVioletRed", + "PapayaWhip", + "PeachPuff", + "Peru", + "Pink", + "Plum", + "PowderBlue", + "Purple", + "Red", + "RosyBrown", + "RoyalBlue", + "SaddleBrown", + "Green", + "SandyBrown", + "SeaGreen", + "SeaShell", + "Sienna", + "Silver", + "SkyBlue", + "SlateBlue", + "SlateGray", + "SlateGrey", + "Snow", + "SpringGreen", + "SteelBlue", + "GreenYellow", + "Teal", + "Thistle", + "Tomato", + "Turquoise", + "Violet", + "Wheat", + "White", + "WhiteSmoke", + "Yellow", + "YellowGreen", ] def save_image_array_as_png(image, output_path): - """Saves an image (represented as a numpy array) to PNG. + """Saves an image (represented as a numpy array) to PNG. - Args: - image: a numpy array with shape [height, width, 3]. - output_path: path to which image should be written. - """ - image_pil = Image.fromarray(np.uint8(image)).convert('RGB') - with tf.gfile.Open(output_path, 'w') as fid: - image_pil.save(fid, 'PNG') + Args: + image: a numpy array with shape [height, width, 3]. + output_path: path to which image should be written. + """ + image_pil = Image.fromarray(np.uint8(image)).convert("RGB") + with tf.gfile.Open(output_path, "w") as fid: + image_pil.save(fid, "PNG") def encode_image_array_as_png_str(image): - """Encodes a numpy array into a PNG string. - - Args: - image: a numpy array with shape [height, width, 3]. - - Returns: - PNG encoded image string. - """ - image_pil = Image.fromarray(np.uint8(image)) - output = six.BytesIO() - image_pil.save(output, format='PNG') - png_string = output.getvalue() - output.close() - return png_string - - -def draw_bounding_box_on_image_array(image, - ymin, - xmin, - ymax, - xmax, - color='red', - thickness=4, - display_str_list=(), - use_normalized_coordinates=True): - """Adds a bounding box to an image (numpy array). - - Bounding box coordinates can be specified in either absolute (pixel) or - normalized coordinates by setting the use_normalized_coordinates argument. - - Args: - image: a numpy array with shape [height, width, 3]. - ymin: ymin of bounding box. - xmin: xmin of bounding box. - ymax: ymax of bounding box. - xmax: xmax of bounding box. - color: color to draw bounding box. Default is red. - thickness: line thickness. 
Default value is 4. - display_str_list: list of strings to display in box - (each to be shown on its own line). - use_normalized_coordinates: If True (default), treat coordinates - ymin, xmin, ymax, xmax as relative to the image. Otherwise treat - coordinates as absolute. - """ - image_pil = Image.fromarray(np.uint8(image)).convert('RGB') - draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color, - thickness, display_str_list, - use_normalized_coordinates) - np.copyto(image, np.array(image_pil)) - - -def draw_bounding_box_on_image(image, - ymin, - xmin, - ymax, - xmax, - color='red', - thickness=4, - display_str_list=(), - use_normalized_coordinates=True): - """Adds a bounding box to an image. - - Bounding box coordinates can be specified in either absolute (pixel) or - normalized coordinates by setting the use_normalized_coordinates argument. - - Each string in display_str_list is displayed on a separate line above the - bounding box in black text on a rectangle filled with the input 'color'. - If the top of the bounding box extends to the edge of the image, the strings - are displayed below the bounding box. - - Args: - image: a PIL.Image object. - ymin: ymin of bounding box. - xmin: xmin of bounding box. - ymax: ymax of bounding box. - xmax: xmax of bounding box. - color: color to draw bounding box. Default is red. - thickness: line thickness. Default value is 4. - display_str_list: list of strings to display in box - (each to be shown on its own line). - use_normalized_coordinates: If True (default), treat coordinates - ymin, xmin, ymax, xmax as relative to the image. Otherwise treat - coordinates as absolute. - """ - draw = ImageDraw.Draw(image) - im_width, im_height = image.size - if use_normalized_coordinates: - (left, right, top, bottom) = (xmin * im_width, xmax * im_width, - ymin * im_height, ymax * im_height) - else: - (left, right, top, bottom) = (xmin, xmax, ymin, ymax) - draw.line([(left, top), (left, bottom), (right, bottom), - (right, top), (left, top)], width=thickness, fill=color) - try: - font = ImageFont.truetype('arial.ttf', 24) - except IOError: - font = ImageFont.load_default() - - # If the total height of the display strings added to the top of the bounding - # box exceeds the top of the image, stack the strings below the bounding box - # instead of above. - display_str_heights = [font.getsize(ds)[1] for ds in display_str_list] - # Each display_str has a top and bottom margin of 0.05x. - total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights) - - if top > total_display_str_height: - text_bottom = top - else: - text_bottom = bottom + total_display_str_height - # Reverse list and print from bottom to top. - for display_str in display_str_list[::-1]: - text_width, text_height = font.getsize(display_str) - margin = np.ceil(0.05 * text_height) - draw.rectangle( - [(left, text_bottom - text_height - 2 * margin), (left + text_width, - text_bottom)], - fill=color) - draw.text( - (left + margin, text_bottom - text_height - margin), - display_str, - fill='black', - font=font) - text_bottom -= text_height - 2 * margin - - -def draw_bounding_boxes_on_image_array(image, - boxes, - color='red', - thickness=4, - display_str_list_list=()): - """Draws bounding boxes on image (numpy array). - - Args: - image: a numpy array object. - boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). - The coordinates are in normalized format between [0, 1]. - color: color to draw bounding box. Default is red. - thickness: line thickness. 
Default value is 4. - display_str_list_list: list of list of strings. - a list of strings for each bounding box. - The reason to pass a list of strings for a - bounding box is that it might contain - multiple labels. - - Raises: - ValueError: if boxes is not a [N, 4] array - """ - image_pil = Image.fromarray(image) - draw_bounding_boxes_on_image(image_pil, boxes, color, thickness, - display_str_list_list) - np.copyto(image, np.array(image_pil)) - - -def draw_bounding_boxes_on_image(image, - boxes, - color='red', - thickness=4, - display_str_list_list=()): - """Draws bounding boxes on image. - - Args: - image: a PIL.Image object. - boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). - The coordinates are in normalized format between [0, 1]. - color: color to draw bounding box. Default is red. - thickness: line thickness. Default value is 4. - display_str_list_list: list of list of strings. - a list of strings for each bounding box. - The reason to pass a list of strings for a - bounding box is that it might contain - multiple labels. - - Raises: - ValueError: if boxes is not a [N, 4] array - """ - boxes_shape = boxes.shape - if not boxes_shape: - return - if len(boxes_shape) != 2 or boxes_shape[1] != 4: - raise ValueError('Input must be of size [N, 4]') - for i in range(boxes_shape[0]): - display_str_list = () - if display_str_list_list: - display_str_list = display_str_list_list[i] - draw_bounding_box_on_image(image, boxes[i, 0], boxes[i, 1], boxes[i, 2], - boxes[i, 3], color, thickness, display_str_list) + """Encodes a numpy array into a PNG string. + + Args: + image: a numpy array with shape [height, width, 3]. + + Returns: + PNG encoded image string. + """ + image_pil = Image.fromarray(np.uint8(image)) + output = six.BytesIO() + image_pil.save(output, format="PNG") + png_string = output.getvalue() + output.close() + return png_string + + +def draw_bounding_box_on_image_array( + image, + ymin, + xmin, + ymax, + xmax, + color="red", + thickness=4, + display_str_list=(), + use_normalized_coordinates=True, +): + """Adds a bounding box to an image (numpy array). + + Bounding box coordinates can be specified in either absolute (pixel) or + normalized coordinates by setting the use_normalized_coordinates argument. + + Args: + image: a numpy array with shape [height, width, 3]. + ymin: ymin of bounding box. + xmin: xmin of bounding box. + ymax: ymax of bounding box. + xmax: xmax of bounding box. + color: color to draw bounding box. Default is red. + thickness: line thickness. Default value is 4. + display_str_list: list of strings to display in box + (each to be shown on its own line). + use_normalized_coordinates: If True (default), treat coordinates + ymin, xmin, ymax, xmax as relative to the image. Otherwise treat + coordinates as absolute. + """ + image_pil = Image.fromarray(np.uint8(image)).convert("RGB") + draw_bounding_box_on_image( + image_pil, + ymin, + xmin, + ymax, + xmax, + color, + thickness, + display_str_list, + use_normalized_coordinates, + ) + np.copyto(image, np.array(image_pil)) + + +def draw_bounding_box_on_image( + image, + ymin, + xmin, + ymax, + xmax, + color="red", + thickness=4, + display_str_list=(), + use_normalized_coordinates=True, +): + """Adds a bounding box to an image. + + Bounding box coordinates can be specified in either absolute (pixel) or + normalized coordinates by setting the use_normalized_coordinates argument. 
+ + Each string in display_str_list is displayed on a separate line above the + bounding box in black text on a rectangle filled with the input 'color'. + If the top of the bounding box extends to the edge of the image, the strings + are displayed below the bounding box. + + Args: + image: a PIL.Image object. + ymin: ymin of bounding box. + xmin: xmin of bounding box. + ymax: ymax of bounding box. + xmax: xmax of bounding box. + color: color to draw bounding box. Default is red. + thickness: line thickness. Default value is 4. + display_str_list: list of strings to display in box + (each to be shown on its own line). + use_normalized_coordinates: If True (default), treat coordinates + ymin, xmin, ymax, xmax as relative to the image. Otherwise treat + coordinates as absolute. + """ + draw = ImageDraw.Draw(image) + im_width, im_height = image.size + if use_normalized_coordinates: + (left, right, top, bottom) = ( + xmin * im_width, + xmax * im_width, + ymin * im_height, + ymax * im_height, + ) + else: + (left, right, top, bottom) = (xmin, xmax, ymin, ymax) + draw.line( + [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)], + width=thickness, + fill=color, + ) + try: + font = ImageFont.truetype("arial.ttf", 24) + except IOError: + font = ImageFont.load_default() + + # If the total height of the display strings added to the top of the bounding + # box exceeds the top of the image, stack the strings below the bounding box + # instead of above. + display_str_heights = [font.getsize(ds)[1] for ds in display_str_list] + # Each display_str has a top and bottom margin of 0.05x. + total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights) + + if top > total_display_str_height: + text_bottom = top + else: + text_bottom = bottom + total_display_str_height + # Reverse list and print from bottom to top. + for display_str in display_str_list[::-1]: + text_width, text_height = font.getsize(display_str) + margin = np.ceil(0.05 * text_height) + draw.rectangle( + [ + (left, text_bottom - text_height - 2 * margin), + (left + text_width, text_bottom), + ], + fill=color, + ) + draw.text( + (left + margin, text_bottom - text_height - margin), + display_str, + fill="black", + font=font, + ) + text_bottom -= text_height - 2 * margin + + +def draw_bounding_boxes_on_image_array( + image, boxes, color="red", thickness=4, display_str_list_list=() +): + """Draws bounding boxes on image (numpy array). + + Args: + image: a numpy array object. + boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). + The coordinates are in normalized format between [0, 1]. + color: color to draw bounding box. Default is red. + thickness: line thickness. Default value is 4. + display_str_list_list: list of list of strings. + a list of strings for each bounding box. + The reason to pass a list of strings for a + bounding box is that it might contain + multiple labels. + + Raises: + ValueError: if boxes is not a [N, 4] array + """ + image_pil = Image.fromarray(image) + draw_bounding_boxes_on_image( + image_pil, boxes, color, thickness, display_str_list_list + ) + np.copyto(image, np.array(image_pil)) + + +def draw_bounding_boxes_on_image( + image, boxes, color="red", thickness=4, display_str_list_list=() +): + """Draws bounding boxes on image. + + Args: + image: a PIL.Image object. + boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). + The coordinates are in normalized format between [0, 1]. + color: color to draw bounding box. Default is red. + thickness: line thickness. 
Default value is 4. + display_str_list_list: list of list of strings. + a list of strings for each bounding box. + The reason to pass a list of strings for a + bounding box is that it might contain + multiple labels. + + Raises: + ValueError: if boxes is not a [N, 4] array + """ + boxes_shape = boxes.shape + if not boxes_shape: + return + if len(boxes_shape) != 2 or boxes_shape[1] != 4: + raise ValueError("Input must be of size [N, 4]") + for i in range(boxes_shape[0]): + display_str_list = () + if display_str_list_list: + display_str_list = display_str_list_list[i] + draw_bounding_box_on_image( + image, + boxes[i, 0], + boxes[i, 1], + boxes[i, 2], + boxes[i, 3], + color, + thickness, + display_str_list, + ) def _visualize_boxes(image, boxes, classes, scores, category_index, **kwargs): - return visualize_boxes_and_labels_on_image_array( - image, boxes, classes, scores, category_index=category_index, **kwargs) - - -def _visualize_boxes_and_masks(image, boxes, classes, scores, masks, - category_index, **kwargs): - return visualize_boxes_and_labels_on_image_array( - image, - boxes, - classes, - scores, - category_index=category_index, - instance_masks=masks, - **kwargs) - - -def _visualize_boxes_and_keypoints(image, boxes, classes, scores, keypoints, - category_index, **kwargs): - return visualize_boxes_and_labels_on_image_array( - image, - boxes, - classes, - scores, - category_index=category_index, - keypoints=keypoints, - **kwargs) + return visualize_boxes_and_labels_on_image_array( + image, boxes, classes, scores, category_index=category_index, **kwargs + ) -def _visualize_boxes_and_masks_and_keypoints( - image, boxes, classes, scores, masks, keypoints, category_index, **kwargs): - return visualize_boxes_and_labels_on_image_array( - image, - boxes, - classes, - scores, - category_index=category_index, - instance_masks=masks, - keypoints=keypoints, - **kwargs) - - -def draw_bounding_boxes_on_image_tensors(images, - boxes, - classes, - scores, - category_index, - instance_masks=None, - keypoints=None, - max_boxes_to_draw=20, - min_score_thresh=0.2): - """Draws bounding boxes, masks, and keypoints on batch of image tensors. - - Args: - images: A 4D uint8 image tensor of shape [N, H, W, C]. - boxes: [N, max_detections, 4] float32 tensor of detection boxes. - classes: [N, max_detections] int tensor of detection classes. Note that - classes are 1-indexed. - scores: [N, max_detections] float32 tensor of detection scores. - category_index: a dict that maps integer ids to category dicts. e.g. - {1: {1: 'dog'}, 2: {2: 'cat'}, ...} - instance_masks: A 4D uint8 tensor of shape [N, max_detection, H, W] with - instance masks. - keypoints: A 4D float32 tensor of shape [N, max_detection, num_keypoints, 2] - with keypoints. - max_boxes_to_draw: Maximum number of boxes to draw on an image. Default 20. - min_score_thresh: Minimum score threshold for visualization. Default 0.2. - - Returns: - 4D image tensor of type uint8, with boxes drawn on top. 
- """ - visualization_keyword_args = { - 'use_normalized_coordinates': True, - 'max_boxes_to_draw': max_boxes_to_draw, - 'min_score_thresh': min_score_thresh, - 'agnostic_mode': False, - 'line_thickness': 4 - } - - if instance_masks is not None and keypoints is None: - visualize_boxes_fn = functools.partial( - _visualize_boxes_and_masks, - category_index=category_index, - **visualization_keyword_args) - elems = [images, boxes, classes, scores, instance_masks] - elif instance_masks is None and keypoints is not None: - visualize_boxes_fn = functools.partial( - _visualize_boxes_and_keypoints, +def _visualize_boxes_and_masks( + image, boxes, classes, scores, masks, category_index, **kwargs +): + return visualize_boxes_and_labels_on_image_array( + image, + boxes, + classes, + scores, category_index=category_index, - **visualization_keyword_args) - elems = [images, boxes, classes, scores, keypoints] - elif instance_masks is not None and keypoints is not None: - visualize_boxes_fn = functools.partial( - _visualize_boxes_and_masks_and_keypoints, + instance_masks=masks, + **kwargs + ) + + +def _visualize_boxes_and_keypoints( + image, boxes, classes, scores, keypoints, category_index, **kwargs +): + return visualize_boxes_and_labels_on_image_array( + image, + boxes, + classes, + scores, category_index=category_index, - **visualization_keyword_args) - elems = [images, boxes, classes, scores, instance_masks, keypoints] - else: - visualize_boxes_fn = functools.partial( - _visualize_boxes, + keypoints=keypoints, + **kwargs + ) + + +def _visualize_boxes_and_masks_and_keypoints( + image, boxes, classes, scores, masks, keypoints, category_index, **kwargs +): + return visualize_boxes_and_labels_on_image_array( + image, + boxes, + classes, + scores, category_index=category_index, - **visualization_keyword_args) - elems = [images, boxes, classes, scores] - - def draw_boxes(image_and_detections): - """Draws boxes on image.""" - image_with_boxes = tf.py_func(visualize_boxes_fn, image_and_detections, - tf.uint8) - return image_with_boxes - - images = tf.map_fn(draw_boxes, elems, dtype=tf.uint8, back_prop=False) - return images - - -def draw_side_by_side_evaluation_image(eval_dict, - category_index, - max_boxes_to_draw=20, - min_score_thresh=0.2): - """Creates a side-by-side image with detections and groundtruth. - - Bounding boxes (and instance masks, if available) are visualized on both - subimages. - - Args: - eval_dict: The evaluation dictionary returned by - eval_util.result_dict_for_single_example(). - category_index: A category index (dictionary) produced from a labelmap. - max_boxes_to_draw: The maximum number of boxes to draw for detections. - min_score_thresh: The minimum score threshold for showing detections. - - Returns: - A [1, H, 2 * W, C] uint8 tensor. The subimage on the left corresponds to - detections, while the subimage on the right corresponds to groundtruth. 
- """ - detection_fields = fields.DetectionResultFields() - input_data_fields = fields.InputDataFields() - instance_masks = None - if detection_fields.detection_masks in eval_dict: - instance_masks = tf.cast( - tf.expand_dims(eval_dict[detection_fields.detection_masks], axis=0), - tf.uint8) - keypoints = None - if detection_fields.detection_keypoints in eval_dict: - keypoints = tf.expand_dims( - eval_dict[detection_fields.detection_keypoints], axis=0) - groundtruth_instance_masks = None - if input_data_fields.groundtruth_instance_masks in eval_dict: - groundtruth_instance_masks = tf.cast( + instance_masks=masks, + keypoints=keypoints, + **kwargs + ) + + +def draw_bounding_boxes_on_image_tensors( + images, + boxes, + classes, + scores, + category_index, + instance_masks=None, + keypoints=None, + max_boxes_to_draw=20, + min_score_thresh=0.2, +): + """Draws bounding boxes, masks, and keypoints on batch of image tensors. + + Args: + images: A 4D uint8 image tensor of shape [N, H, W, C]. + boxes: [N, max_detections, 4] float32 tensor of detection boxes. + classes: [N, max_detections] int tensor of detection classes. Note that + classes are 1-indexed. + scores: [N, max_detections] float32 tensor of detection scores. + category_index: a dict that maps integer ids to category dicts. e.g. + {1: {1: 'dog'}, 2: {2: 'cat'}, ...} + instance_masks: A 4D uint8 tensor of shape [N, max_detection, H, W] with + instance masks. + keypoints: A 4D float32 tensor of shape [N, max_detection, num_keypoints, 2] + with keypoints. + max_boxes_to_draw: Maximum number of boxes to draw on an image. Default 20. + min_score_thresh: Minimum score threshold for visualization. Default 0.2. + + Returns: + 4D image tensor of type uint8, with boxes drawn on top. + """ + visualization_keyword_args = { + "use_normalized_coordinates": True, + "max_boxes_to_draw": max_boxes_to_draw, + "min_score_thresh": min_score_thresh, + "agnostic_mode": False, + "line_thickness": 4, + } + + if instance_masks is not None and keypoints is None: + visualize_boxes_fn = functools.partial( + _visualize_boxes_and_masks, + category_index=category_index, + **visualization_keyword_args + ) + elems = [images, boxes, classes, scores, instance_masks] + elif instance_masks is None and keypoints is not None: + visualize_boxes_fn = functools.partial( + _visualize_boxes_and_keypoints, + category_index=category_index, + **visualization_keyword_args + ) + elems = [images, boxes, classes, scores, keypoints] + elif instance_masks is not None and keypoints is not None: + visualize_boxes_fn = functools.partial( + _visualize_boxes_and_masks_and_keypoints, + category_index=category_index, + **visualization_keyword_args + ) + elems = [images, boxes, classes, scores, instance_masks, keypoints] + else: + visualize_boxes_fn = functools.partial( + _visualize_boxes, + category_index=category_index, + **visualization_keyword_args + ) + elems = [images, boxes, classes, scores] + + def draw_boxes(image_and_detections): + """Draws boxes on image.""" + image_with_boxes = tf.py_func( + visualize_boxes_fn, image_and_detections, tf.uint8 + ) + return image_with_boxes + + images = tf.map_fn(draw_boxes, elems, dtype=tf.uint8, back_prop=False) + return images + + +def draw_side_by_side_evaluation_image( + eval_dict, category_index, max_boxes_to_draw=20, min_score_thresh=0.2 +): + """Creates a side-by-side image with detections and groundtruth. + + Bounding boxes (and instance masks, if available) are visualized on both + subimages. 
+ + Args: + eval_dict: The evaluation dictionary returned by + eval_util.result_dict_for_single_example(). + category_index: A category index (dictionary) produced from a labelmap. + max_boxes_to_draw: The maximum number of boxes to draw for detections. + min_score_thresh: The minimum score threshold for showing detections. + + Returns: + A [1, H, 2 * W, C] uint8 tensor. The subimage on the left corresponds to + detections, while the subimage on the right corresponds to groundtruth. + """ + detection_fields = fields.DetectionResultFields() + input_data_fields = fields.InputDataFields() + instance_masks = None + if detection_fields.detection_masks in eval_dict: + instance_masks = tf.cast( + tf.expand_dims(eval_dict[detection_fields.detection_masks], axis=0), + tf.uint8, + ) + keypoints = None + if detection_fields.detection_keypoints in eval_dict: + keypoints = tf.expand_dims( + eval_dict[detection_fields.detection_keypoints], axis=0 + ) + groundtruth_instance_masks = None + if input_data_fields.groundtruth_instance_masks in eval_dict: + groundtruth_instance_masks = tf.cast( + tf.expand_dims( + eval_dict[input_data_fields.groundtruth_instance_masks], axis=0 + ), + tf.uint8, + ) + images_with_detections = draw_bounding_boxes_on_image_tensors( + eval_dict[input_data_fields.original_image], + tf.expand_dims(eval_dict[detection_fields.detection_boxes], axis=0), + tf.expand_dims(eval_dict[detection_fields.detection_classes], axis=0), + tf.expand_dims(eval_dict[detection_fields.detection_scores], axis=0), + category_index, + instance_masks=instance_masks, + keypoints=keypoints, + max_boxes_to_draw=max_boxes_to_draw, + min_score_thresh=min_score_thresh, + ) + images_with_groundtruth = draw_bounding_boxes_on_image_tensors( + eval_dict[input_data_fields.original_image], + tf.expand_dims(eval_dict[input_data_fields.groundtruth_boxes], axis=0), + tf.expand_dims(eval_dict[input_data_fields.groundtruth_classes], axis=0), tf.expand_dims( - eval_dict[input_data_fields.groundtruth_instance_masks], axis=0), - tf.uint8) - images_with_detections = draw_bounding_boxes_on_image_tensors( - eval_dict[input_data_fields.original_image], - tf.expand_dims(eval_dict[detection_fields.detection_boxes], axis=0), - tf.expand_dims(eval_dict[detection_fields.detection_classes], axis=0), - tf.expand_dims(eval_dict[detection_fields.detection_scores], axis=0), - category_index, - instance_masks=instance_masks, - keypoints=keypoints, - max_boxes_to_draw=max_boxes_to_draw, - min_score_thresh=min_score_thresh) - images_with_groundtruth = draw_bounding_boxes_on_image_tensors( - eval_dict[input_data_fields.original_image], - tf.expand_dims(eval_dict[input_data_fields.groundtruth_boxes], axis=0), - tf.expand_dims(eval_dict[input_data_fields.groundtruth_classes], axis=0), - tf.expand_dims( - tf.ones_like( - eval_dict[input_data_fields.groundtruth_classes], - dtype=tf.float32), - axis=0), - category_index, - instance_masks=groundtruth_instance_masks, - keypoints=None, - max_boxes_to_draw=None, - min_score_thresh=0.0) - return tf.concat([images_with_detections, images_with_groundtruth], axis=2) - - -def draw_keypoints_on_image_array(image, - keypoints, - color='red', - radius=2, - use_normalized_coordinates=True): - """Draws keypoints on an image (numpy array). - - Args: - image: a numpy array with shape [height, width, 3]. - keypoints: a numpy array with shape [num_keypoints, 2]. - color: color to draw the keypoints with. Default is red. - radius: keypoint radius. Default value is 2. 
- use_normalized_coordinates: if True (default), treat keypoint values as - relative to the image. Otherwise treat them as absolute. - """ - image_pil = Image.fromarray(np.uint8(image)).convert('RGB') - draw_keypoints_on_image(image_pil, keypoints, color, radius, - use_normalized_coordinates) - np.copyto(image, np.array(image_pil)) - - -def draw_keypoints_on_image(image, - keypoints, - color='red', - radius=2, - use_normalized_coordinates=True): - """Draws keypoints on an image. - - Args: - image: a PIL.Image object. - keypoints: a numpy array with shape [num_keypoints, 2]. - color: color to draw the keypoints with. Default is red. - radius: keypoint radius. Default value is 2. - use_normalized_coordinates: if True (default), treat keypoint values as - relative to the image. Otherwise treat them as absolute. - """ - draw = ImageDraw.Draw(image) - im_width, im_height = image.size - keypoints_x = [k[1] for k in keypoints] - keypoints_y = [k[0] for k in keypoints] - if use_normalized_coordinates: - keypoints_x = tuple([im_width * x for x in keypoints_x]) - keypoints_y = tuple([im_height * y for y in keypoints_y]) - for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y): - draw.ellipse([(keypoint_x - radius, keypoint_y - radius), - (keypoint_x + radius, keypoint_y + radius)], - outline=color, fill=color) - - -def draw_mask_on_image_array(image, mask, color='red', alpha=0.4): - """Draws mask on an image. - - Args: - image: uint8 numpy array with shape (img_height, img_height, 3) - mask: a uint8 numpy array of shape (img_height, img_height) with - values between either 0 or 1. - color: color to draw the keypoints with. Default is red. - alpha: transparency value between 0 and 1. (default: 0.4) - - Raises: - ValueError: On incorrect data type for image or masks. - """ - if image.dtype != np.uint8: - raise ValueError('`image` not of type np.uint8') - if mask.dtype != np.uint8: - raise ValueError('`mask` not of type np.uint8') - if np.any(np.logical_and(mask != 1, mask != 0)): - raise ValueError('`mask` elements should be in [0, 1]') - if image.shape[:2] != mask.shape: - raise ValueError('The image has spatial dimensions %s but the mask has ' - 'dimensions %s' % (image.shape[:2], mask.shape)) - rgb = ImageColor.getrgb(color) - pil_image = Image.fromarray(image) - - solid_color = np.expand_dims( - np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3]) - pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA') - pil_mask = Image.fromarray(np.uint8(255.0*alpha*mask)).convert('L') - pil_image = Image.composite(pil_solid_color, pil_image, pil_mask) - np.copyto(image, np.array(pil_image.convert('RGB'))) + tf.ones_like( + eval_dict[input_data_fields.groundtruth_classes], dtype=tf.float32 + ), + axis=0, + ), + category_index, + instance_masks=groundtruth_instance_masks, + keypoints=None, + max_boxes_to_draw=None, + min_score_thresh=0.0, + ) + return tf.concat([images_with_detections, images_with_groundtruth], axis=2) + + +def draw_keypoints_on_image_array( + image, keypoints, color="red", radius=2, use_normalized_coordinates=True +): + """Draws keypoints on an image (numpy array). + + Args: + image: a numpy array with shape [height, width, 3]. + keypoints: a numpy array with shape [num_keypoints, 2]. + color: color to draw the keypoints with. Default is red. + radius: keypoint radius. Default value is 2. + use_normalized_coordinates: if True (default), treat keypoint values as + relative to the image. Otherwise treat them as absolute. 
+ """ + image_pil = Image.fromarray(np.uint8(image)).convert("RGB") + draw_keypoints_on_image( + image_pil, keypoints, color, radius, use_normalized_coordinates + ) + np.copyto(image, np.array(image_pil)) + + +def draw_keypoints_on_image( + image, keypoints, color="red", radius=2, use_normalized_coordinates=True +): + """Draws keypoints on an image. + + Args: + image: a PIL.Image object. + keypoints: a numpy array with shape [num_keypoints, 2]. + color: color to draw the keypoints with. Default is red. + radius: keypoint radius. Default value is 2. + use_normalized_coordinates: if True (default), treat keypoint values as + relative to the image. Otherwise treat them as absolute. + """ + draw = ImageDraw.Draw(image) + im_width, im_height = image.size + keypoints_x = [k[1] for k in keypoints] + keypoints_y = [k[0] for k in keypoints] + if use_normalized_coordinates: + keypoints_x = tuple([im_width * x for x in keypoints_x]) + keypoints_y = tuple([im_height * y for y in keypoints_y]) + for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y): + draw.ellipse( + [ + (keypoint_x - radius, keypoint_y - radius), + (keypoint_x + radius, keypoint_y + radius), + ], + outline=color, + fill=color, + ) + + +def draw_mask_on_image_array(image, mask, color="red", alpha=0.4): + """Draws mask on an image. + + Args: + image: uint8 numpy array with shape (img_height, img_height, 3) + mask: a uint8 numpy array of shape (img_height, img_height) with + values between either 0 or 1. + color: color to draw the keypoints with. Default is red. + alpha: transparency value between 0 and 1. (default: 0.4) + + Raises: + ValueError: On incorrect data type for image or masks. + """ + if image.dtype != np.uint8: + raise ValueError("`image` not of type np.uint8") + if mask.dtype != np.uint8: + raise ValueError("`mask` not of type np.uint8") + if np.any(np.logical_and(mask != 1, mask != 0)): + raise ValueError("`mask` elements should be in [0, 1]") + if image.shape[:2] != mask.shape: + raise ValueError( + "The image has spatial dimensions %s but the mask has " + "dimensions %s" % (image.shape[:2], mask.shape) + ) + rgb = ImageColor.getrgb(color) + pil_image = Image.fromarray(image) + + solid_color = np.expand_dims(np.ones_like(mask), axis=2) * np.reshape( + list(rgb), [1, 1, 3] + ) + pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert("RGBA") + pil_mask = Image.fromarray(np.uint8(255.0 * alpha * mask)).convert("L") + pil_image = Image.composite(pil_solid_color, pil_image, pil_mask) + np.copyto(image, np.array(pil_image.convert("RGB"))) def visualize_boxes_and_labels_on_image_array( @@ -536,185 +695,189 @@ def visualize_boxes_and_labels_on_image_array( keypoints=None, use_normalized_coordinates=False, max_boxes_to_draw=20, - min_score_thresh=.5, + min_score_thresh=0.5, agnostic_mode=False, line_thickness=4, - groundtruth_box_visualization_color='black', + groundtruth_box_visualization_color="black", skip_scores=False, - skip_labels=False): - """Overlay labeled boxes on an image with formatted scores and label names. - - This function groups boxes that correspond to the same location - and creates a display string for each detection and overlays these - on the image. Note that this function modifies the image in place, and returns - that same image. - - Args: - image: uint8 numpy array with shape (img_height, img_width, 3) - boxes: a numpy array of shape [N, 4] - classes: a numpy array of shape [N]. Note that class indices are 1-based, - and match the keys in the label map. 
- scores: a numpy array of shape [N] or None. If scores=None, then - this function assumes that the boxes to be plotted are groundtruth - boxes and plot all boxes as black with no classes or scores. - category_index: a dict containing category dictionaries (each holding - category index `id` and category name `name`) keyed by category indices. - instance_masks: a numpy array of shape [N, image_height, image_width] with - values ranging between 0 and 1, can be None. - instance_boundaries: a numpy array of shape [N, image_height, image_width] - with values ranging between 0 and 1, can be None. - keypoints: a numpy array of shape [N, num_keypoints, 2], can - be None - use_normalized_coordinates: whether boxes is to be interpreted as - normalized coordinates or not. - max_boxes_to_draw: maximum number of boxes to visualize. If None, draw - all boxes. - min_score_thresh: minimum score threshold for a box to be visualized - agnostic_mode: boolean (default: False) controlling whether to evaluate in - class-agnostic mode or not. This mode will display scores but ignore - classes. - line_thickness: integer (default: 4) controlling line width of the boxes. - groundtruth_box_visualization_color: box color for visualizing groundtruth - boxes - skip_scores: whether to skip score when drawing a single detection - skip_labels: whether to skip label when drawing a single detection - - Returns: - uint8 numpy array with shape (img_height, img_width, 3) with overlaid boxes. - """ - # Create a display string (and color) for every box location, group any boxes - # that correspond to the same location. - box_to_display_str_map = collections.defaultdict(list) - box_to_color_map = collections.defaultdict(str) - box_to_instance_masks_map = {} - box_to_instance_boundaries_map = {} - box_to_keypoints_map = collections.defaultdict(list) - if not max_boxes_to_draw: - max_boxes_to_draw = boxes.shape[0] - for i in range(min(max_boxes_to_draw, boxes.shape[0])): - if scores is None or scores[i] > min_score_thresh: - box = tuple(boxes[i].tolist()) - if instance_masks is not None: - box_to_instance_masks_map[box] = instance_masks[i] - if instance_boundaries is not None: - box_to_instance_boundaries_map[box] = instance_boundaries[i] - if keypoints is not None: - box_to_keypoints_map[box].extend(keypoints[i]) - if scores is None: - box_to_color_map[box] = groundtruth_box_visualization_color - else: - display_str = '' - if not skip_labels: - if not agnostic_mode: - if classes[i] in category_index.keys(): - class_name = category_index[classes[i]]['name'] + skip_labels=False, +): + """Overlay labeled boxes on an image with formatted scores and label names. + + This function groups boxes that correspond to the same location + and creates a display string for each detection and overlays these + on the image. Note that this function modifies the image in place, and returns + that same image. + + Args: + image: uint8 numpy array with shape (img_height, img_width, 3) + boxes: a numpy array of shape [N, 4] + classes: a numpy array of shape [N]. Note that class indices are 1-based, + and match the keys in the label map. + scores: a numpy array of shape [N] or None. If scores=None, then + this function assumes that the boxes to be plotted are groundtruth + boxes and plot all boxes as black with no classes or scores. + category_index: a dict containing category dictionaries (each holding + category index `id` and category name `name`) keyed by category indices. 
+ instance_masks: a numpy array of shape [N, image_height, image_width] with + values ranging between 0 and 1, can be None. + instance_boundaries: a numpy array of shape [N, image_height, image_width] + with values ranging between 0 and 1, can be None. + keypoints: a numpy array of shape [N, num_keypoints, 2], can + be None + use_normalized_coordinates: whether boxes is to be interpreted as + normalized coordinates or not. + max_boxes_to_draw: maximum number of boxes to visualize. If None, draw + all boxes. + min_score_thresh: minimum score threshold for a box to be visualized + agnostic_mode: boolean (default: False) controlling whether to evaluate in + class-agnostic mode or not. This mode will display scores but ignore + classes. + line_thickness: integer (default: 4) controlling line width of the boxes. + groundtruth_box_visualization_color: box color for visualizing groundtruth + boxes + skip_scores: whether to skip score when drawing a single detection + skip_labels: whether to skip label when drawing a single detection + + Returns: + uint8 numpy array with shape (img_height, img_width, 3) with overlaid boxes. + """ + # Create a display string (and color) for every box location, group any boxes + # that correspond to the same location. + box_to_display_str_map = collections.defaultdict(list) + box_to_color_map = collections.defaultdict(str) + box_to_instance_masks_map = {} + box_to_instance_boundaries_map = {} + box_to_keypoints_map = collections.defaultdict(list) + if not max_boxes_to_draw: + max_boxes_to_draw = boxes.shape[0] + for i in range(min(max_boxes_to_draw, boxes.shape[0])): + if scores is None or scores[i] > min_score_thresh: + box = tuple(boxes[i].tolist()) + if instance_masks is not None: + box_to_instance_masks_map[box] = instance_masks[i] + if instance_boundaries is not None: + box_to_instance_boundaries_map[box] = instance_boundaries[i] + if keypoints is not None: + box_to_keypoints_map[box].extend(keypoints[i]) + if scores is None: + box_to_color_map[box] = groundtruth_box_visualization_color else: - class_name = 'N/A' - display_str = str(class_name) - if not skip_scores: - if not display_str: - display_str = '{}%'.format(int(100*scores[i])) - else: - display_str = '{}: {}%'.format(display_str, int(100*scores[i])) - box_to_display_str_map[box].append(display_str) - if agnostic_mode: - box_to_color_map[box] = 'DarkOrange' - else: - box_to_color_map[box] = STANDARD_COLORS[ - classes[i] % len(STANDARD_COLORS)] - - # Draw all boxes onto image. 
- for box, color in box_to_color_map.items(): - ymin, xmin, ymax, xmax = box - if instance_masks is not None: - draw_mask_on_image_array( - image, - box_to_instance_masks_map[box], - color=color - ) - if instance_boundaries is not None: - draw_mask_on_image_array( - image, - box_to_instance_boundaries_map[box], - color='red', - alpha=1.0 - ) - draw_bounding_box_on_image_array( - image, - ymin, - xmin, - ymax, - xmax, - color=color, - thickness=line_thickness, - display_str_list=box_to_display_str_map[box], - use_normalized_coordinates=use_normalized_coordinates) - if keypoints is not None: - draw_keypoints_on_image_array( - image, - box_to_keypoints_map[box], - color=color, - radius=line_thickness / 2, - use_normalized_coordinates=use_normalized_coordinates) + display_str = "" + if not skip_labels: + if not agnostic_mode: + if classes[i] in category_index.keys(): + class_name = category_index[classes[i]]["name"] + else: + class_name = "N/A" + display_str = str(class_name) + if not skip_scores: + if not display_str: + display_str = "{}%".format(int(100 * scores[i])) + else: + display_str = "{}: {}%".format( + display_str, int(100 * scores[i]) + ) + box_to_display_str_map[box].append(display_str) + if agnostic_mode: + box_to_color_map[box] = "DarkOrange" + else: + box_to_color_map[box] = STANDARD_COLORS[ + classes[i] % len(STANDARD_COLORS) + ] + + # Draw all boxes onto image. + for box, color in box_to_color_map.items(): + ymin, xmin, ymax, xmax = box + if instance_masks is not None: + draw_mask_on_image_array(image, box_to_instance_masks_map[box], color=color) + if instance_boundaries is not None: + draw_mask_on_image_array( + image, box_to_instance_boundaries_map[box], color="red", alpha=1.0 + ) + draw_bounding_box_on_image_array( + image, + ymin, + xmin, + ymax, + xmax, + color=color, + thickness=line_thickness, + display_str_list=box_to_display_str_map[box], + use_normalized_coordinates=use_normalized_coordinates, + ) + if keypoints is not None: + draw_keypoints_on_image_array( + image, + box_to_keypoints_map[box], + color=color, + radius=line_thickness / 2, + use_normalized_coordinates=use_normalized_coordinates, + ) - return image + return image def add_cdf_image_summary(values, name): - """Adds a tf.summary.image for a CDF plot of the values. - - Normalizes `values` such that they sum to 1, plots the cumulative distribution - function and creates a tf image summary. - - Args: - values: a 1-D float32 tensor containing the values. - name: name for the image summary. - """ - def cdf_plot(values): - """Numpy function to plot CDF.""" - normalized_values = values / np.sum(values) - sorted_values = np.sort(normalized_values) - cumulative_values = np.cumsum(sorted_values) - fraction_of_examples = (np.arange(cumulative_values.size, dtype=np.float32) - / cumulative_values.size) - fig = plt.figure(frameon=False) - ax = fig.add_subplot('111') - ax.plot(fraction_of_examples, cumulative_values) - ax.set_ylabel('cumulative normalized values') - ax.set_xlabel('fraction of examples') - fig.canvas.draw() - width, height = fig.get_size_inches() * fig.get_dpi() - image = np.fromstring(fig.canvas.tostring_rgb(), dtype='uint8').reshape( - 1, int(height), int(width), 3) - return image - cdf_plot = tf.py_func(cdf_plot, [values], tf.uint8) - tf.summary.image(name, cdf_plot) + """Adds a tf.summary.image for a CDF plot of the values. + + Normalizes `values` such that they sum to 1, plots the cumulative distribution + function and creates a tf image summary. 
+ + Args: + values: a 1-D float32 tensor containing the values. + name: name for the image summary. + """ + + def cdf_plot(values): + """Numpy function to plot CDF.""" + normalized_values = values / np.sum(values) + sorted_values = np.sort(normalized_values) + cumulative_values = np.cumsum(sorted_values) + fraction_of_examples = ( + np.arange(cumulative_values.size, dtype=np.float32) / cumulative_values.size + ) + fig = plt.figure(frameon=False) + ax = fig.add_subplot("111") + ax.plot(fraction_of_examples, cumulative_values) + ax.set_ylabel("cumulative normalized values") + ax.set_xlabel("fraction of examples") + fig.canvas.draw() + width, height = fig.get_size_inches() * fig.get_dpi() + image = np.fromstring(fig.canvas.tostring_rgb(), dtype="uint8").reshape( + 1, int(height), int(width), 3 + ) + return image + + cdf_plot = tf.py_func(cdf_plot, [values], tf.uint8) + tf.summary.image(name, cdf_plot) def add_hist_image_summary(values, bins, name): - """Adds a tf.summary.image for a histogram plot of the values. - - Plots the histogram of values and creates a tf image summary. - - Args: - values: a 1-D float32 tensor containing the values. - bins: bin edges which will be directly passed to np.histogram. - name: name for the image summary. - """ - - def hist_plot(values, bins): - """Numpy function to plot hist.""" - fig = plt.figure(frameon=False) - ax = fig.add_subplot('111') - y, x = np.histogram(values, bins=bins) - ax.plot(x[:-1], y) - ax.set_ylabel('count') - ax.set_xlabel('value') - fig.canvas.draw() - width, height = fig.get_size_inches() * fig.get_dpi() - image = np.fromstring( - fig.canvas.tostring_rgb(), dtype='uint8').reshape( - 1, int(height), int(width), 3) - return image - hist_plot = tf.py_func(hist_plot, [values, bins], tf.uint8) - tf.summary.image(name, hist_plot) + """Adds a tf.summary.image for a histogram plot of the values. + + Plots the histogram of values and creates a tf image summary. + + Args: + values: a 1-D float32 tensor containing the values. + bins: bin edges which will be directly passed to np.histogram. + name: name for the image summary. + """ + + def hist_plot(values, bins): + """Numpy function to plot hist.""" + fig = plt.figure(frameon=False) + ax = fig.add_subplot("111") + y, x = np.histogram(values, bins=bins) + ax.plot(x[:-1], y) + ax.set_ylabel("count") + ax.set_xlabel("value") + fig.canvas.draw() + width, height = fig.get_size_inches() * fig.get_dpi() + image = np.fromstring(fig.canvas.tostring_rgb(), dtype="uint8").reshape( + 1, int(height), int(width), 3 + ) + return image + + hist_plot = tf.py_func(hist_plot, [values, bins], tf.uint8) + tf.summary.image(name, hist_plot) diff --git a/ai-emlyon/eml/datanavig.py b/ai-emlyon/eml/datanavig.py index 61d06476..8c111d16 100644 --- a/ai-emlyon/eml/datanavig.py +++ b/ai-emlyon/eml/datanavig.py @@ -1,17 +1,18 @@ -#----------------------------IMPORTS---------------------------- +# ----------------------------IMPORTS---------------------------- import pandas as pd import numpy as np -#----------------------------FUNCTIONS---------------------------- +# ----------------------------FUNCTIONS---------------------------- -#Drop cols based on a treshold, as int (abs count) or float (proportion) -def drop_cols (df, tresh=0.5, inplace=True, print_cols=True): - """ + +# Drop cols based on a treshold, as int (abs count) or float (proportion) +def drop_cols(df, tresh=0.5, inplace=True, print_cols=True): + """ Drop all columns with the count of non-null values less than the treshold specified. 
- Return a dictionnary with keys equal to the names of dropped columns, + Return a dictionnary with keys equal to the names of dropped columns, and values equal to the type of dropped columns. - + Parameters ---------- df : Pandas DataFrame @@ -30,7 +31,7 @@ def drop_cols (df, tresh=0.5, inplace=True, print_cols=True): Return ------ - Return a dictionnary with keys equal to the names of dropped columns, + Return a dictionnary with keys equal to the names of dropped columns, and values equal to the type of dropped columns. Example: {'energy_100g' : float, 'stores' : object} @@ -38,380 +39,413 @@ def drop_cols (df, tresh=0.5, inplace=True, print_cols=True): """ cols_dropped = {} cols = df.columns.to_list() - if isinstance(tresh, float): #if tresh is float - treshold = int(len(df) * tresh) - else: #if tresh is int + if isinstance(tresh, float): # if tresh is float + treshold = int(len(df) * tresh) + else: # if tresh is int treshold = tresh - for col in cols : #loop over columns - coltype = str(df[col].dtype) #save col dtype - col_count = df[col].count() #save col count (non-null values) - if col_count < treshold : - df.drop([col], axis=1, inplace=inplace) #drop the col - to_update = {col:[coltype, col_count]} - cols_dropped.update(to_update) #append the dict - if print_cols: #if print_cols is True + for col in cols: # loop over columns + coltype = str(df[col].dtype) # save col dtype + col_count = df[col].count() # save col count (non-null values) + if col_count < treshold: + df.drop([col], axis=1, inplace=inplace) # drop the col + to_update = {col: [coltype, col_count]} + cols_dropped.update(to_update) # append the dict + if print_cols: # if print_cols is True print(f"Column removed : {col} --- dtype : {coltype}") return cols_dropped -#Drop cols in the off_cols_dict + +# Drop cols in the off_cols_dict def remove_cols_dict(cols_dict, cols_to_drop): if isinstance(cols_to_drop, dict): - cols_to_drop = [col for col in cols_to_drop.keys()] #extract columns names in a list - for group in cols_dict.keys(): #loop over the cols dict - cols_dict[group] = [col for col in cols_dict[group] if col not in cols_to_drop] #keep only cols names not in cols_to_drop + cols_to_drop = [ + col for col in cols_to_drop.keys() + ] # extract columns names in a list + for group in cols_dict.keys(): # loop over the cols dict + cols_dict[group] = [ + col for col in cols_dict[group] if col not in cols_to_drop + ] # keep only cols names not in cols_to_drop + -#----------------------------DATA---------------------------- +# ----------------------------DATA---------------------------- -ppns2_list = ['appetizers', - 'artificially sweetened beverages', - 'biscuits and cakes', - 'bread', - 'breakfast cereals', - 'cereals', - 'cheese', - 'chocolate products', - 'dairy desserts', - 'dressings and sauces', - 'dried fruits', - 'eggs', - 'fats', - 'fish and seafood', - 'fruit juices', - 'fruit nectars', - 'fruits', - 'ice cream', - 'legumes', - 'meat', - 'milk and yogurt', - 'nuts', - 'offals', - 'one dish meals', - 'pastries', - 'pizza pies and quiche', - 'plant based milk substitutes', - 'potatoes', - 'processed meat', - 'salty and fatty products', - 'sandwiches', - 'soups', - 'sweetened beverages', - 'sweets', - 'teas and herbal teas and coffees', - 'unsweetened beverages', - 'vegetables', - 'waters and flavored waters'] +ppns2_list = [ + "appetizers", + "artificially sweetened beverages", + "biscuits and cakes", + "bread", + "breakfast cereals", + "cereals", + "cheese", + "chocolate products", + "dairy desserts", + "dressings and 
sauces", + "dried fruits", + "eggs", + "fats", + "fish and seafood", + "fruit juices", + "fruit nectars", + "fruits", + "ice cream", + "legumes", + "meat", + "milk and yogurt", + "nuts", + "offals", + "one dish meals", + "pastries", + "pizza pies and quiche", + "plant based milk substitutes", + "potatoes", + "processed meat", + "salty and fatty products", + "sandwiches", + "soups", + "sweetened beverages", + "sweets", + "teas and herbal teas and coffees", + "unsweetened beverages", + "vegetables", + "waters and flavored waters", +] -off_columns_dict = {'robotoff':['robotoff_countries', 'robotoff_brands', 'robotoff_value_tag', -'robotoff_data.lang', 'robotoff_data.model', 'robotoff_data.confidence'], -'meta':['code', - 'url', - 'creator', - 'created_t', - 'created_datetime', - 'last_modified_t', - 'last_modified_datetime', - 'emb_codes', - 'emb_codes_tags', - 'states', - 'states_tags', - 'states_en'], - 'image':['image_url', - 'image_small_url', - 'image_ingredients_url', - 'image_ingredients_small_url', - 'image_nutrition_url', - 'image_nutrition_small_url'], - 'infopdt':[ 'product_name', - 'abbreviated_product_name', - 'generic_name', - 'quantity', - 'packaging', - 'packaging_tags', - 'packaging_text', - 'brands', - 'brands_tags', - 'stores', - 'serving_size', - 'serving_quantity', - 'nutriscore_score', - 'nutriscore_grade', - 'nova_group', - 'brand_owner'], - 'ingredients':[ 'labels', - 'labels_tags', - 'labels_en', - 'ingredients_text', - 'allergens', - 'traces', - 'traces_tags', - 'traces_en', - 'additives_n', - 'additives_tags', - 'additives_en', - 'ingredients_from_palm_oil_n', - 'ingredients_from_palm_oil_tags', - 'ingredients_that_may_be_from_palm_oil_n', - 'ingredients_that_may_be_from_palm_oil_tags', - 'nutriscore_score', - 'nutriscore_grade', - 'nova_group'], - 'cats':[ 'main_category', - 'main_category_en', - 'pnns_groups_1', - 'pnns_groups_2', - 'categories', - 'categories_tags', - 'categories_en'], - 'geo':['origins', - 'origins_tags', - 'origins_en', - 'manufacturing_places', - 'manufacturing_places_tags', - 'cities_tags', - 'countries', - 'countries_tags', - 'countries_en', - 'first_packaging_code_geo'], - 'empty':['cities', - 'allergens_en', - 'no_nutriments', - 'ingredients_from_palm_oil', - 'ingredients_that_may_be_from_palm_oil'], - 'dummy':['pnns_groups_1', - 'pnns_groups_2', - 'nutriscore_grade'], - 'tags':['categories_tags', - 'manufacturing_places_tags', - 'labels_tags', - 'emb_codes_tags', - 'cities_tags', - 'countries_tags'], - '100g':['energy-kj_100g', - 'energy-kcal_100g', - 'energy_100g', - 'energy-from-fat_100g', - 'fat_100g', - 'saturated-fat_100g', - '-butyric-acid_100g', - '-caproic-acid_100g', - '-caprylic-acid_100g', - '-capric-acid_100g', - '-lauric-acid_100g', - '-myristic-acid_100g', - '-palmitic-acid_100g', - '-stearic-acid_100g', - '-arachidic-acid_100g', - '-behenic-acid_100g', - '-lignoceric-acid_100g', - '-cerotic-acid_100g', - '-montanic-acid_100g', - '-melissic-acid_100g', - 'monounsaturated-fat_100g', - 'polyunsaturated-fat_100g', - 'omega-3-fat_100g', - '-alpha-linolenic-acid_100g', - '-eicosapentaenoic-acid_100g', - '-docosahexaenoic-acid_100g', - 'omega-6-fat_100g', - '-linoleic-acid_100g', - '-arachidonic-acid_100g', - '-gamma-linolenic-acid_100g', - '-dihomo-gamma-linolenic-acid_100g', - 'omega-9-fat_100g', - '-oleic-acid_100g', - '-elaidic-acid_100g', - '-gondoic-acid_100g', - '-mead-acid_100g', - '-erucic-acid_100g', - '-nervonic-acid_100g', - 'trans-fat_100g', - 'cholesterol_100g', - 'carbohydrates_100g', - 'sugars_100g', - 
'-sucrose_100g', - '-glucose_100g', - '-fructose_100g', - '-lactose_100g', - '-maltose_100g', - '-maltodextrins_100g', - 'starch_100g', - 'polyols_100g', - 'fiber_100g', - '-soluble-fiber_100g', - '-insoluble-fiber_100g', - 'proteins_100g', - 'casein_100g', - 'serum-proteins_100g', - 'nucleotides_100g', - 'salt_100g', - 'sodium_100g', - 'alcohol_100g', - 'vitamin-a_100g', - 'beta-carotene_100g', - 'vitamin-d_100g', - 'vitamin-e_100g', - 'vitamin-k_100g', - 'vitamin-c_100g', - 'vitamin-b1_100g', - 'vitamin-b2_100g', - 'vitamin-pp_100g', - 'vitamin-b6_100g', - 'vitamin-b9_100g', - 'folates_100g', - 'vitamin-b12_100g', - 'biotin_100g', - 'pantothenic-acid_100g', - 'silica_100g', - 'bicarbonate_100g', - 'potassium_100g', - 'chloride_100g', - 'calcium_100g', - 'phosphorus_100g', - 'iron_100g', - 'magnesium_100g', - 'zinc_100g', - 'copper_100g', - 'manganese_100g', - 'fluoride_100g', - 'selenium_100g', - 'chromium_100g', - 'molybdenum_100g', - 'iodine_100g', - 'caffeine_100g', - 'taurine_100g', - 'ph_100g', - 'fruits-vegetables-nuts_100g', - 'fruits-vegetables-nuts-dried_100g', - 'fruits-vegetables-nuts-estimate_100g', - 'collagen-meat-protein-ratio_100g', - 'cocoa_100g', - 'chlorophyl_100g', - 'carbon-footprint_100g', - 'carbon-footprint-from-meat-or-fish_100g', - 'nutrition-score-fr_100g', - 'nutrition-score-uk_100g', - 'glycemic-index_100g', - 'water-hardness_100g', - 'choline_100g', - 'phylloquinone_100g', - 'beta-glucan_100g', - 'inositol_100g', - 'carnitine_100g'], - 'numeric':['serving_quantity', -'additives_n', -'ingredients_from_palm_oil_n', -'ingredients_that_may_be_from_palm_oil_n', -'nutriscore_score', -'nova_group', -'energy-kj_100g', - 'energy-kcal_100g', - 'energy_100g', - 'energy-from-fat_100g', - 'fat_100g', - 'saturated-fat_100g', - '-butyric-acid_100g', - '-caproic-acid_100g', - '-caprylic-acid_100g', - '-capric-acid_100g', - '-lauric-acid_100g', - '-myristic-acid_100g', - '-palmitic-acid_100g', - '-stearic-acid_100g', - '-arachidic-acid_100g', - '-behenic-acid_100g', - '-lignoceric-acid_100g', - '-cerotic-acid_100g', - '-montanic-acid_100g', - '-melissic-acid_100g', - 'monounsaturated-fat_100g', - 'polyunsaturated-fat_100g', - 'omega-3-fat_100g', - '-alpha-linolenic-acid_100g', - '-eicosapentaenoic-acid_100g', - '-docosahexaenoic-acid_100g', - 'omega-6-fat_100g', - '-linoleic-acid_100g', - '-arachidonic-acid_100g', - '-gamma-linolenic-acid_100g', - '-dihomo-gamma-linolenic-acid_100g', - 'omega-9-fat_100g', - '-oleic-acid_100g', - '-elaidic-acid_100g', - '-gondoic-acid_100g', - '-mead-acid_100g', - '-erucic-acid_100g', - '-nervonic-acid_100g', - 'trans-fat_100g', - 'cholesterol_100g', - 'carbohydrates_100g', - 'sugars_100g', - '-sucrose_100g', - '-glucose_100g', - '-fructose_100g', - '-lactose_100g', - '-maltose_100g', - '-maltodextrins_100g', - 'starch_100g', - 'polyols_100g', - 'fiber_100g', - '-soluble-fiber_100g', - '-insoluble-fiber_100g', - 'proteins_100g', - 'casein_100g', - 'serum-proteins_100g', - 'nucleotides_100g', - 'salt_100g', - 'sodium_100g', - 'alcohol_100g', - 'vitamin-a_100g', - 'beta-carotene_100g', - 'vitamin-d_100g', - 'vitamin-e_100g', - 'vitamin-k_100g', - 'vitamin-c_100g', - 'vitamin-b1_100g', - 'vitamin-b2_100g', - 'vitamin-pp_100g', - 'vitamin-b6_100g', - 'vitamin-b9_100g', - 'folates_100g', - 'vitamin-b12_100g', - 'biotin_100g', - 'pantothenic-acid_100g', - 'silica_100g', - 'bicarbonate_100g', - 'potassium_100g', - 'chloride_100g', - 'calcium_100g', - 'phosphorus_100g', - 'iron_100g', - 'magnesium_100g', - 'zinc_100g', - 'copper_100g', - 
'manganese_100g', - 'fluoride_100g', - 'selenium_100g', - 'chromium_100g', - 'molybdenum_100g', - 'iodine_100g', - 'caffeine_100g', - 'taurine_100g', - 'ph_100g', - 'fruits-vegetables-nuts_100g', - 'fruits-vegetables-nuts-dried_100g', - 'fruits-vegetables-nuts-estimate_100g', - 'collagen-meat-protein-ratio_100g', - 'cocoa_100g', - 'chlorophyl_100g', - 'carbon-footprint_100g', - 'carbon-footprint-from-meat-or-fish_100g', - 'nutrition-score-fr_100g', - 'nutrition-score-uk_100g', - 'glycemic-index_100g', - 'water-hardness_100g', - 'choline_100g', - 'phylloquinone_100g', - 'beta-glucan_100g', - 'inositol_100g', - 'carnitine_100g'] - } \ No newline at end of file +off_columns_dict = { + "robotoff": [ + "robotoff_countries", + "robotoff_brands", + "robotoff_value_tag", + "robotoff_data.lang", + "robotoff_data.model", + "robotoff_data.confidence", + ], + "meta": [ + "code", + "url", + "creator", + "created_t", + "created_datetime", + "last_modified_t", + "last_modified_datetime", + "emb_codes", + "emb_codes_tags", + "states", + "states_tags", + "states_en", + ], + "image": [ + "image_url", + "image_small_url", + "image_ingredients_url", + "image_ingredients_small_url", + "image_nutrition_url", + "image_nutrition_small_url", + ], + "infopdt": [ + "product_name", + "abbreviated_product_name", + "generic_name", + "quantity", + "packaging", + "packaging_tags", + "packaging_text", + "brands", + "brands_tags", + "stores", + "serving_size", + "serving_quantity", + "nutriscore_score", + "nutriscore_grade", + "nova_group", + "brand_owner", + ], + "ingredients": [ + "labels", + "labels_tags", + "labels_en", + "ingredients_text", + "allergens", + "traces", + "traces_tags", + "traces_en", + "additives_n", + "additives_tags", + "additives_en", + "ingredients_from_palm_oil_n", + "ingredients_from_palm_oil_tags", + "ingredients_that_may_be_from_palm_oil_n", + "ingredients_that_may_be_from_palm_oil_tags", + "nutriscore_score", + "nutriscore_grade", + "nova_group", + ], + "cats": [ + "main_category", + "main_category_en", + "pnns_groups_1", + "pnns_groups_2", + "categories", + "categories_tags", + "categories_en", + ], + "geo": [ + "origins", + "origins_tags", + "origins_en", + "manufacturing_places", + "manufacturing_places_tags", + "cities_tags", + "countries", + "countries_tags", + "countries_en", + "first_packaging_code_geo", + ], + "empty": [ + "cities", + "allergens_en", + "no_nutriments", + "ingredients_from_palm_oil", + "ingredients_that_may_be_from_palm_oil", + ], + "dummy": ["pnns_groups_1", "pnns_groups_2", "nutriscore_grade"], + "tags": [ + "categories_tags", + "manufacturing_places_tags", + "labels_tags", + "emb_codes_tags", + "cities_tags", + "countries_tags", + ], + "100g": [ + "energy-kj_100g", + "energy-kcal_100g", + "energy_100g", + "energy-from-fat_100g", + "fat_100g", + "saturated-fat_100g", + "-butyric-acid_100g", + "-caproic-acid_100g", + "-caprylic-acid_100g", + "-capric-acid_100g", + "-lauric-acid_100g", + "-myristic-acid_100g", + "-palmitic-acid_100g", + "-stearic-acid_100g", + "-arachidic-acid_100g", + "-behenic-acid_100g", + "-lignoceric-acid_100g", + "-cerotic-acid_100g", + "-montanic-acid_100g", + "-melissic-acid_100g", + "monounsaturated-fat_100g", + "polyunsaturated-fat_100g", + "omega-3-fat_100g", + "-alpha-linolenic-acid_100g", + "-eicosapentaenoic-acid_100g", + "-docosahexaenoic-acid_100g", + "omega-6-fat_100g", + "-linoleic-acid_100g", + "-arachidonic-acid_100g", + "-gamma-linolenic-acid_100g", + "-dihomo-gamma-linolenic-acid_100g", + "omega-9-fat_100g", + "-oleic-acid_100g", + 
"-elaidic-acid_100g", + "-gondoic-acid_100g", + "-mead-acid_100g", + "-erucic-acid_100g", + "-nervonic-acid_100g", + "trans-fat_100g", + "cholesterol_100g", + "carbohydrates_100g", + "sugars_100g", + "-sucrose_100g", + "-glucose_100g", + "-fructose_100g", + "-lactose_100g", + "-maltose_100g", + "-maltodextrins_100g", + "starch_100g", + "polyols_100g", + "fiber_100g", + "-soluble-fiber_100g", + "-insoluble-fiber_100g", + "proteins_100g", + "casein_100g", + "serum-proteins_100g", + "nucleotides_100g", + "salt_100g", + "sodium_100g", + "alcohol_100g", + "vitamin-a_100g", + "beta-carotene_100g", + "vitamin-d_100g", + "vitamin-e_100g", + "vitamin-k_100g", + "vitamin-c_100g", + "vitamin-b1_100g", + "vitamin-b2_100g", + "vitamin-pp_100g", + "vitamin-b6_100g", + "vitamin-b9_100g", + "folates_100g", + "vitamin-b12_100g", + "biotin_100g", + "pantothenic-acid_100g", + "silica_100g", + "bicarbonate_100g", + "potassium_100g", + "chloride_100g", + "calcium_100g", + "phosphorus_100g", + "iron_100g", + "magnesium_100g", + "zinc_100g", + "copper_100g", + "manganese_100g", + "fluoride_100g", + "selenium_100g", + "chromium_100g", + "molybdenum_100g", + "iodine_100g", + "caffeine_100g", + "taurine_100g", + "ph_100g", + "fruits-vegetables-nuts_100g", + "fruits-vegetables-nuts-dried_100g", + "fruits-vegetables-nuts-estimate_100g", + "collagen-meat-protein-ratio_100g", + "cocoa_100g", + "chlorophyl_100g", + "carbon-footprint_100g", + "carbon-footprint-from-meat-or-fish_100g", + "nutrition-score-fr_100g", + "nutrition-score-uk_100g", + "glycemic-index_100g", + "water-hardness_100g", + "choline_100g", + "phylloquinone_100g", + "beta-glucan_100g", + "inositol_100g", + "carnitine_100g", + ], + "numeric": [ + "serving_quantity", + "additives_n", + "ingredients_from_palm_oil_n", + "ingredients_that_may_be_from_palm_oil_n", + "nutriscore_score", + "nova_group", + "energy-kj_100g", + "energy-kcal_100g", + "energy_100g", + "energy-from-fat_100g", + "fat_100g", + "saturated-fat_100g", + "-butyric-acid_100g", + "-caproic-acid_100g", + "-caprylic-acid_100g", + "-capric-acid_100g", + "-lauric-acid_100g", + "-myristic-acid_100g", + "-palmitic-acid_100g", + "-stearic-acid_100g", + "-arachidic-acid_100g", + "-behenic-acid_100g", + "-lignoceric-acid_100g", + "-cerotic-acid_100g", + "-montanic-acid_100g", + "-melissic-acid_100g", + "monounsaturated-fat_100g", + "polyunsaturated-fat_100g", + "omega-3-fat_100g", + "-alpha-linolenic-acid_100g", + "-eicosapentaenoic-acid_100g", + "-docosahexaenoic-acid_100g", + "omega-6-fat_100g", + "-linoleic-acid_100g", + "-arachidonic-acid_100g", + "-gamma-linolenic-acid_100g", + "-dihomo-gamma-linolenic-acid_100g", + "omega-9-fat_100g", + "-oleic-acid_100g", + "-elaidic-acid_100g", + "-gondoic-acid_100g", + "-mead-acid_100g", + "-erucic-acid_100g", + "-nervonic-acid_100g", + "trans-fat_100g", + "cholesterol_100g", + "carbohydrates_100g", + "sugars_100g", + "-sucrose_100g", + "-glucose_100g", + "-fructose_100g", + "-lactose_100g", + "-maltose_100g", + "-maltodextrins_100g", + "starch_100g", + "polyols_100g", + "fiber_100g", + "-soluble-fiber_100g", + "-insoluble-fiber_100g", + "proteins_100g", + "casein_100g", + "serum-proteins_100g", + "nucleotides_100g", + "salt_100g", + "sodium_100g", + "alcohol_100g", + "vitamin-a_100g", + "beta-carotene_100g", + "vitamin-d_100g", + "vitamin-e_100g", + "vitamin-k_100g", + "vitamin-c_100g", + "vitamin-b1_100g", + "vitamin-b2_100g", + "vitamin-pp_100g", + "vitamin-b6_100g", + "vitamin-b9_100g", + "folates_100g", + "vitamin-b12_100g", + "biotin_100g", + 
"pantothenic-acid_100g", + "silica_100g", + "bicarbonate_100g", + "potassium_100g", + "chloride_100g", + "calcium_100g", + "phosphorus_100g", + "iron_100g", + "magnesium_100g", + "zinc_100g", + "copper_100g", + "manganese_100g", + "fluoride_100g", + "selenium_100g", + "chromium_100g", + "molybdenum_100g", + "iodine_100g", + "caffeine_100g", + "taurine_100g", + "ph_100g", + "fruits-vegetables-nuts_100g", + "fruits-vegetables-nuts-dried_100g", + "fruits-vegetables-nuts-estimate_100g", + "collagen-meat-protein-ratio_100g", + "cocoa_100g", + "chlorophyl_100g", + "carbon-footprint_100g", + "carbon-footprint-from-meat-or-fish_100g", + "nutrition-score-fr_100g", + "nutrition-score-uk_100g", + "glycemic-index_100g", + "water-hardness_100g", + "choline_100g", + "phylloquinone_100g", + "beta-glucan_100g", + "inositol_100g", + "carnitine_100g", + ], +} diff --git a/ai-emlyon/eml/eda.py b/ai-emlyon/eml/eda.py index 79959b02..599c1765 100644 --- a/ai-emlyon/eml/eda.py +++ b/ai-emlyon/eml/eda.py @@ -1,12 +1,15 @@ -import pandas as pd +import pandas as pd import seaborn as sns import matplotlib.pyplot as plt -import numpy as np +import numpy as np + def plot_correlations(df): corr_matrix = df.corr() mask = np.zeros_like(corr_matrix, dtype=np.bool) - mask[np.triu_indices_from(mask)]= True - plt.figure(figsize=(25,15)) - heatmap = sns.heatmap(corr_matrix,vmin=-1, vmax=1, annot=True, cmap='RdBu', mask=mask) - heatmap.set_title("Correlation Heatmap") \ No newline at end of file + mask[np.triu_indices_from(mask)] = True + plt.figure(figsize=(25, 15)) + heatmap = sns.heatmap( + corr_matrix, vmin=-1, vmax=1, annot=True, cmap="RdBu", mask=mask + ) + heatmap.set_title("Correlation Heatmap") diff --git a/ai-emlyon/eml/model_eval.py b/ai-emlyon/eml/model_eval.py index edbd68aa..dacf2e31 100644 --- a/ai-emlyon/eml/model_eval.py +++ b/ai-emlyon/eml/model_eval.py @@ -5,36 +5,57 @@ import numpy as np import itertools -def classifier_metrics (y_test=None, y_preds=None, average='weighted', model=str, zero_division=0): - """Return Accuracy, Recall, Precision and F-1 score. - Average can take two arguments : macro or weighted """ + +def classifier_metrics( + y_test=None, y_preds=None, average="weighted", model=str, zero_division=0 +): + """Return Accuracy, Recall, Precision and F-1 score. 
+ Average can take two arguments : macro or weighted""" acc = metrics.accuracy_score(y_test, y_preds) - rec = metrics.recall_score(y_test, y_preds, average = average, zero_division=zero_division) - prc = metrics.precision_score(y_test, y_preds, average = average, zero_division=zero_division) - f1 = metrics.f1_score(y_test, y_preds, average = average, zero_division=zero_division) - print (f"{model} Classification Metrics :") - print ("-------------------") - print('Accuracy : {:.2f}%'.format(acc*100)) - print('Recall : {:.2f}%'.format(rec*100)) - print('Precision : {:.2f}%'.format(prc*100)) - print('F1-score : {:.2f}%'.format(f1*100)) - print('\n') - -def get_classification_report(y_test, y_pred, sortby='f1-score', model=str): + rec = metrics.recall_score( + y_test, y_preds, average=average, zero_division=zero_division + ) + prc = metrics.precision_score( + y_test, y_preds, average=average, zero_division=zero_division + ) + f1 = metrics.f1_score(y_test, y_preds, average=average, zero_division=zero_division) + print(f"{model} Classification Metrics :") + print("-------------------") + print("Accuracy : {:.2f}%".format(acc * 100)) + print("Recall : {:.2f}%".format(rec * 100)) + print("Precision : {:.2f}%".format(prc * 100)) + print("F1-score : {:.2f}%".format(f1 * 100)) + print("\n") + + +def get_classification_report(y_test, y_pred, sortby="f1-score", model=str): """Return a classification report as pd.DataFrame""" - report = metrics.classification_report(y_test, y_pred, output_dict=True, zero_division=0) + report = metrics.classification_report( + y_test, y_pred, output_dict=True, zero_division=0 + ) df_classification_report = pd.DataFrame(report).transpose() - df_classification_report = df_classification_report.sort_values(by=[sortby], ascending=False) - df_classification_report.rename(columns={colname: model + '_' + colname for colname in df_classification_report.columns}, inplace=True) + df_classification_report = df_classification_report.sort_values( + by=[sortby], ascending=False + ) + df_classification_report.rename( + columns={ + colname: model + "_" + colname + for colname in df_classification_report.columns + }, + inplace=True, + ) return df_classification_report.round(2) -def plot_confusion_matrix(cm, - target_names=None, - title='Confusion matrix', - cmap=None, - normalize=True, - figsize=(10,10)): + +def plot_confusion_matrix( + cm, + target_names=None, + title="Confusion matrix", + cmap=None, + normalize=True, + figsize=(10, 10), +): """ given a sklearn confusion matrix (cm), make a nice plot @@ -68,19 +89,19 @@ def plot_confusion_matrix(cm, """ - accuracy = np.trace(cm) / np.sum(cm).astype('float') + accuracy = np.trace(cm) / np.sum(cm).astype("float") misclass = 1 - accuracy if cmap is None: - cmap = plt.get_cmap('Blues') + cmap = plt.get_cmap("Blues") if normalize: - cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] + cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] plt.figure(figsize=figsize) - plt.imshow(cm, interpolation='nearest', cmap=cmap) + plt.imshow(cm, interpolation="nearest", cmap=cmap) plt.title(title) - #plt.colorbar() + # plt.colorbar() if target_names is not None: tick_marks = np.arange(len(target_names)) @@ -90,15 +111,25 @@ def plot_confusion_matrix(cm, thresh = cm.max() / 1.5 if normalize else cm.max() / 2 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): if normalize: - plt.text(j, i, "{:0.3f}".format(cm[i, j]), - horizontalalignment="center", - color="white" if cm[i, j] > thresh else "black") + plt.text( + j, + i, + 
"{:0.3f}".format(cm[i, j]), + horizontalalignment="center", + color="white" if cm[i, j] > thresh else "black", + ) else: - plt.text(j, i, "{:,}".format(cm[i, j]), - horizontalalignment="center", - color="white" if cm[i, j] > thresh else "black") + plt.text( + j, + i, + "{:,}".format(cm[i, j]), + horizontalalignment="center", + color="white" if cm[i, j] > thresh else "black", + ) plt.tight_layout() - plt.ylabel('True label') - plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass)) - plt.show() \ No newline at end of file + plt.ylabel("True label") + plt.xlabel( + "Predicted label\naccuracy={:0.4f}; misclass={:0.4f}".format(accuracy, misclass) + ) + plt.show() diff --git a/ai-emlyon/eml/taxonomy.py b/ai-emlyon/eml/taxonomy.py index 7737f4a5..e2607cab 100644 --- a/ai-emlyon/eml/taxonomy.py +++ b/ai-emlyon/eml/taxonomy.py @@ -1,61 +1,75 @@ -#Imports +# Imports import re import numpy as np import seaborn as sns from robotoff.products import ProductDataset from robotoff.taxonomy import get_taxonomy -#Setup +# Setup sns.set() ds = ProductDataset.load() -taxonomy = get_taxonomy('category') +taxonomy = get_taxonomy("category") -def get_taxonomy_info (category=str, info_type='tax_list'): - + +def get_taxonomy_info(category=str, info_type="tax_list"): """ Return taxonomy info for a category tag, with cleaned nodes strings. - Can be a list of strings with all tags in the taxonomy branch or - a dict of integers with distances within the taxonomy branch + Can be a list of strings with all tags in the taxonomy branch or + a dict of integers with distances within the taxonomy branch If the taxonomy does not exist for the category filled, return np.nan - + Parameters ------------ - - category = str, mandatory + - category = str, mandatory Category tag to explore. - info_type = 'childs', 'parents', full_tax_list', 'tax_distance' - tax_list : Return a list with all the nodes (childs and parents) within the category taxonomy. - tax_distance : Return a dict with nb_parents, nb_childs and total nodes. - + """ - + nonetype = type(None) cat_taxonomy = taxonomy[category] - - #If taxonomy exist, get information + + # If taxonomy exist, get information if not isinstance(cat_taxonomy, nonetype): - cat_childs_list = [re.sub(' 1: + # return pnns 1 or more + elif len(pnns_candidates) > 1: output = pnns_candidates[pnns_index] - - #---- Search duplicates option ---- - #If pnns already exist in another group + # ---- Search duplicates option ---- + + # If pnns already exist in another group if search_duplicates: - existing_values = [row[f'pnns_groups_{i}'] for i in range (1, nb_groups)] - if output in existing_values: - #try to find another pnns + existing_values = [row[f"pnns_groups_{i}"] for i in range(1, nb_groups)] + if output in existing_values: + # try to find another pnns for i in range(len(pnns_candidates)): output = pnns_candidates[i] if output not in existing_values: break - #return output + # return output return output + def find_pnns_groups_1(df): """ Find unknown pnns groups 1 based on known pnns groups 2. 
""" data = df.copy() vals_to_find = list(data.pnns_groups_1.unique()) - vals_to_find.remove('unknown') + vals_to_find.remove("unknown") for val in vals_to_find: - group_2_vals = list(data['pnns_groups_2'].loc[data['pnns_groups_1'] == val].unique()) - data['pnns_groups_1'].loc[(data['pnns_groups_1'] == 'unknown') & - (data['pnns_groups_2'].isin(group_2_vals))] = val + group_2_vals = list( + data["pnns_groups_2"].loc[data["pnns_groups_1"] == val].unique() + ) + data["pnns_groups_1"].loc[ + (data["pnns_groups_1"] == "unknown") + & (data["pnns_groups_2"].isin(group_2_vals)) + ] = val return data + def get_ingredients_columns(df, ingredients_list): """Computational expensive Create a new column for each item in ing_list - and fill with percent_estimate or median if the row contains ingredient and 0 if not""" - data = df.copy(deep=True) #make a copy of df - for i in ingredients_list: data[i] = 0 #create 0 columns, 1 per ingredient - for ingre_list, index in zip(data.ingredients, data.index): #loop over df rows - for ingre_dict in ingre_list: #loop continue in the dicts in each row - val = ingre_dict['text'] #get ingredient text - val_clean = val.replace('_','').replace('-','').strip('').lower() - if val_clean in ingredients_list: #check if val is in columns added + and fill with percent_estimate or median if the row contains ingredient and 0 if not + """ + data = df.copy(deep=True) # make a copy of df + for i in ingredients_list: + data[i] = 0 # create 0 columns, 1 per ingredient + for ingre_list, index in zip(data.ingredients, data.index): # loop over df rows + for ingre_dict in ingre_list: # loop continue in the dicts in each row + val = ingre_dict["text"] # get ingredient text + val_clean = val.replace("_", "").replace("-", "").strip("").lower() + if val_clean in ingredients_list: # check if val is in columns added try: - data.loc[index,val_clean] = ingre_dict['percent_estimate'] #if yes, replace by estimate + data.loc[index, val_clean] = ingre_dict[ + "percent_estimate" + ] # if yes, replace by estimate except: - pnns_val = data.loc[index,'pnns_groups_1'] - val_median = data[val_clean].loc[data['pnns_groups_1'] == pnns_val].median() - data.loc[index,val_clean] = val_median - return data #return modified dataframe + pnns_val = data.loc[index, "pnns_groups_1"] + val_median = ( + data[val_clean].loc[data["pnns_groups_1"] == pnns_val].median() + ) + data.loc[index, val_clean] = val_median + return data # return modified dataframe -#Older function +# Older function def depreciated_find_pnns(row, df_taxonomy): output = np.nan - tags_list = row.split(',') + tags_list = row.split(",") tax_list = get_taxonomy_info(tags_list[0]) if not isinstance(tax_list, float): - tax_list = tax_list + tags_list - tax_list = [tax.strip(' ').strip('') for tax in tax_list] + tax_list = tax_list + tags_list + tax_list = [tax.strip(" ").strip("") for tax in tax_list] tax_list = list(set(tax_list)) for suggestion, pnns in zip(df_taxonomy.taxonomy_suggestion, df_taxonomy.pnns): if suggestion in tax_list: @@ -212,13 +249,16 @@ def depreciated_find_pnns(row, df_taxonomy): else: continue if isinstance(output, float): - for possibilities, pnns in zip(df_taxonomy.all_taxonomy_possibilities, df_tax.pnns): + for possibilities, pnns in zip( + df_taxonomy.all_taxonomy_possibilities, df_tax.pnns + ): for possibility in possibilities.keys(): if possibility in tax_list: output = pnns break - else : + else: continue return output -#----------------------------END---------------------------- \ No newline at end of file + +# 
----------------------------END---------------------------- diff --git a/ai-emlyon/food_model/evaluator.py b/ai-emlyon/food_model/evaluator.py index 95c9e2a1..cc3812e0 100644 --- a/ai-emlyon/food_model/evaluator.py +++ b/ai-emlyon/food_model/evaluator.py @@ -4,10 +4,12 @@ import seaborn as sns import pickle from sklearn import metrics + sns.set() -class Evaluator(): - + +class Evaluator: + def __init__(self): self.data = None @@ -15,17 +17,20 @@ def __init__(self): self.class_report = None def __repr__(self): - return 'Interpreter object' - + return "Interpreter object" + def build_data( - self, y_true, y_pred, + self, + y_true, + y_pred, y_confidence=None, - pred_type=None, - decode_labels=False, label_encoder='from_existing' - ): + pred_type=None, + decode_labels=False, + label_encoder="from_existing", + ): """ Create the dataset used for other methods - + Arguments --------- - y_true = true labels @@ -37,152 +42,178 @@ def build_data( decode_labels = If true, use a label encoder to decode labels. label_encoder = The label encoder used for decode labels if decode_labels == True. - if label_encoder = 'from_existing' : load the label encoder for the pred type passed in pred_type. - + Return ------ A dataframe with columns : - - y_true, - - y_pred, + - y_true, + - y_pred, - pred_is_true : 1 if y_true == y_pred else 0 - pred_confidence : y_confidence if y_confidence is not None """ - + self.pred_type = pred_type self.y_pred = y_pred self.y_true = y_true self.y_confidence = y_confidence self._decode_labels = decode_labels - + if self._decode_labels: - if label_encoder == 'from_existing': - if pred_type == 'G1': - pkl_file = open('label_encoder_g1.pkl', 'rb') + if label_encoder == "from_existing": + if pred_type == "G1": + pkl_file = open("label_encoder_g1.pkl", "rb") self.le = pickle.load(pkl_file) pkl_file.close() - elif pred_type == 'G2': - pkl_file = open('label_encoder_g2.pkl', 'rb') - self.le = pickle.load(pkl_file) + elif pred_type == "G2": + pkl_file = open("label_encoder_g2.pkl", "rb") + self.le = pickle.load(pkl_file) pkl_file.close() elif label_encoder is not None: self.le = label_encoder - df = pd.DataFrame({'y_true':self.y_true, 'y_pred':self.y_pred}) - df['pred_is_true'] = df.apply( - lambda x: 1 if x['y_true'] == x['y_pred'] else 0, axis=1) + df = pd.DataFrame({"y_true": self.y_true, "y_pred": self.y_pred}) + df["pred_is_true"] = df.apply( + lambda x: 1 if x["y_true"] == x["y_pred"] else 0, axis=1 + ) if self.y_confidence is not None: - df['pred_confidence'] = self.y_confidence + df["pred_confidence"] = self.y_confidence if self._decode_labels: - df['y_true'] = self.le.inverse_transform(df['y_true']) - df['y_pred'] = self.le.inverse_transform(df['y_pred'].astype(int)) - + df["y_true"] = self.le.inverse_transform(df["y_true"]) + df["y_pred"] = self.le.inverse_transform(df["y_pred"].astype(int)) + self.data = df self._data_ready = True return self.data def classification_report( - self, sortby='precision', name='model', - save_report=False, report_path='classification_report.csv' - ): + self, + sortby="precision", + name="model", + save_report=False, + report_path="classification_report.csv", + ): """Return a classification report as pd.DataFrame""" if not self._data_ready: self.build_data() - + report = metrics.classification_report( self.data.y_true, self.data.y_pred, output_dict=True, zero_division=0 - ) + ) df_report = pd.DataFrame(report).transpose() df_report = df_report.sort_values(by=[sortby], ascending=False) df_report.rename( - columns={colname: name + '_' + colname for 
colname in df_report.columns}, inplace=True - ) + columns={colname: name + "_" + colname for colname in df_report.columns}, + inplace=True, + ) self.class_report = df_report.round(2) - + if save_report: df_report.to_csv(report_path, index=True, header=True) - + return self.class_report - - def global_metrics( - self, average='weighted', name='Model', zero_div=0 - ): - """Return Accuracy, Recall, Precision and F-1 score. - Average can take two arguments : macro or weighted """ + + def global_metrics(self, average="weighted", name="Model", zero_div=0): + """Return Accuracy, Recall, Precision and F-1 score. + Average can take two arguments : macro or weighted""" if not self._data_ready: self.build_data() self._data_ready = True acc = metrics.accuracy_score(self.data.y_true, self.data.y_pred) - rec = metrics.recall_score(self.data.y_true, self.data.y_pred, average=average, zero_division=zero_div) + rec = metrics.recall_score( + self.data.y_true, self.data.y_pred, average=average, zero_division=zero_div + ) prc = metrics.precision_score( self.data.y_true, self.data.y_pred, average=average, zero_division=zero_div - ) - f1 = metrics.f1_score(self.data.y_true, self.data.y_pred, average=average, zero_division=zero_div) - print (f"{name} Classification Metrics :") - print ("-"*(len(name)+25)) - print('Accuracy : {:.2f}%'.format(acc*100)) - print('Recall : {:.2f}%'.format(rec*100)) - print('Precision : {:.2f}%'.format(prc*100)) - print('F1-score : {:.2f}%'.format(f1*100)) - print('\n') + ) + f1 = metrics.f1_score( + self.data.y_true, self.data.y_pred, average=average, zero_division=zero_div + ) + print(f"{name} Classification Metrics :") + print("-" * (len(name) + 25)) + print("Accuracy : {:.2f}%".format(acc * 100)) + print("Recall : {:.2f}%".format(rec * 100)) + print("Precision : {:.2f}%".format(prc * 100)) + print("F1-score : {:.2f}%".format(f1 * 100)) + print("\n") def plot_categories_scores( - self, metric='precision', name='Model', figsize=(8,10), - save_fig=False, fig_path="score_by_category.png" - ): + self, + metric="precision", + name="Model", + figsize=(8, 10), + save_fig=False, + fig_path="score_by_category.png", + ): """Point plot with metric score by category, sorted by metric""" if not self._data_ready: self.build_data() report = self.classification_report(sortby=metric, name=name) plt.figure(figsize=figsize) - sns.pointplot(y=report.index, x=report[f'{name}_{metric}'], palette='terrain') - plt.title(f'{name} {self.pred_type} {metric} by category', fontsize=14) + sns.pointplot(y=report.index, x=report[f"{name}_{metric}"], palette="terrain") + plt.title(f"{name} {self.pred_type} {metric} by category", fontsize=14) if save_fig: - sns.savefig(fig_path) + sns.savefig(fig_path) def plot_confidence( - self, name='Model', metric='precision', - col_wrap=5, save_fig=False, - fig_path="confidence_by_category.png" - ): + self, + name="Model", + metric="precision", + col_wrap=5, + save_fig=False, + fig_path="confidence_by_category.png", + ): """Print a KDE Plot with model confidence for every category""" - + if not self._data_ready: self.build_data() - self._data_ready = True + self._data_ready = True report = self.classification_report(sortby=metric, name=name) labels = self.data.y_true.unique() - hue_palette = {0:'darkred', 1:'darkgreen'} + hue_palette = {0: "darkred", 1: "darkgreen"} g = sns.FacetGrid( - self.data, col='y_true', hue='pred_is_true', - col_wrap=col_wrap, height=5,palette=hue_palette, xlim=(0,1) - ) - g.map(sns.kdeplot, 'pred_confidence', fill=True, common_norm=True, alpha=.4) + 
self.data, + col="y_true", + hue="pred_is_true", + col_wrap=col_wrap, + height=5, + palette=hue_palette, + xlim=(0, 1), + ) + g.map(sns.kdeplot, "pred_confidence", fill=True, common_norm=True, alpha=0.4) g.add_legend() for ax, label in zip(g.axes.flat, labels): ax.set_title(f'{label} | {metric} : {report[f"{name}_{metric}"][label]}') - + if save_fig: sns.savefig(fig_path) - + def plot_confusion_matrix( - self, name='Model', figsize=(20,15), annot=True, cmap='Greens', - save_fig=False, fig_path="confidence_by_category.png" - ): + self, + name="Model", + figsize=(20, 15), + annot=True, + cmap="Greens", + save_fig=False, + fig_path="confidence_by_category.png", + ): """Return a confusion matrix with seaborn heatmap design""" - + if not self._data_ready: self.build_data() - self._data_ready = True + self._data_ready = True cm = metrics.confusion_matrix(self.data.y_true, self.data.y_pred) labels = sorted(set(self.data.y_true)) plt.figure(figsize=figsize) - plot = sns.heatmap(cm, xticklabels=labels, yticklabels=labels, annot=annot, cmap=cmap, fmt='g') + plot = sns.heatmap( + cm, xticklabels=labels, yticklabels=labels, annot=annot, cmap=cmap, fmt="g" + ) plot.set_title(f"{name} {self.pred_type} Confusion Matrix") - + if save_fig: - sns.savefig(fig_path) \ No newline at end of file + sns.savefig(fig_path) diff --git a/ai-emlyon/food_model/xgfood.py b/ai-emlyon/food_model/xgfood.py index efe247fe..4d95985e 100644 --- a/ai-emlyon/food_model/xgfood.py +++ b/ai-emlyon/food_model/xgfood.py @@ -7,13 +7,14 @@ import pandas as pd from xgboost import XGBClassifier -class XGFood(): + +class XGFood: def __init__(self): self.model_G1 = XGBClassifier() - self.model_G1.load_model(r'files\xgboost_G1_m2.model') + self.model_G1.load_model(r"files\xgboost_G1_m2.model") self.model_G2 = XGBClassifier() - self.model_G2.load_model(r'files\xgboost_G2_m1.model') + self.model_G2.load_model(r"files\xgboost_G2_m1.model") self.le_G1 = None self.le_G2 = None self._unknown_code_G1 = 9 @@ -24,11 +25,15 @@ def __init__(self): self.X = None def process_X( - self, X_raw, ingredients_column='ingredients', text_column='product_name',verbose=True - ): + self, + X_raw, + ingredients_column="ingredients", + text_column="product_name", + verbose=True, + ): """ Given a structured ingredients column and the column 'product_name', - split information into 948 features (450 most frequent ingredients + split information into 948 features (450 most frequent ingredients and 488 most frequent words in product_name.) Parameters @@ -43,31 +48,36 @@ def process_X( ------ 1. Create an empty dataframe with the name of all features used to train XGBoost (size is 938*N samples) 2. Split ingredients related features and text related features - 3. For each sample, append features related to ingredient + 3. For each sample, append features related to ingredient with percent estimate founded in ingredients col - 4. For each sample, append features related to product_name - with dummies (1 if the word is in product_name, 0 else) + 4. For each sample, append features related to product_name + with dummies (1 if the word is in product_name, 0 else) Return ------ Change self.X to Processed X and self._X_processed from False to True. 
""" - + time_proc_start = time.time() - with open(r'files\features.json') as json_file: + with open(r"files\features.json") as json_file: features = json.load(json_file) - X_empty = pd.DataFrame(columns=features['features_cols']) + X_empty = pd.DataFrame(columns=features["features_cols"]) self.ingredients_column = ingredients_column self.text_column = text_column - if not self._feat_cols_filled: self.features_cols = X_empty.columns.to_list() - self.raw_cols_for_ings = [col for col in self.features_cols if 'ing_' in col] - self.cols_for_ings = [re.sub('ing_', '', col) for col in self.raw_cols_for_ings] - self.cols_for_text = [col for col in self.features_cols if 'ing_' not in col] + self.raw_cols_for_ings = [ + col for col in self.features_cols if "ing_" in col + ] + self.cols_for_ings = [ + re.sub("ing_", "", col) for col in self.raw_cols_for_ings + ] + self.cols_for_text = [ + col for col in self.features_cols if "ing_" not in col + ] self._feat_cols_filled = True - + self.X_raw = X_empty.append(X_raw).fillna(0) self.process_ingredients() self.process_names() @@ -75,86 +85,106 @@ def process_X( self._X_processed = True time_proc_end = (time.time() - time_proc_start) / 60 - if verbose: print(f'Processing done - Total running time : {round(time_proc_end,2)}mn') + if verbose: + print(f"Processing done - Total running time : {round(time_proc_end,2)}mn") return self.X def process_ingredients(self): - #Convert object to original structured list of dictionnaries - self.X_raw[self.ingredients_column] = self.X_raw[self.ingredients_column].apply(literal_eval) - #Loop in the list of dicts - for ingre_list, index in zip(self.X_raw[self.ingredients_column], self.X_raw.index): - #Loop in the dicts + # Convert object to original structured list of dictionnaries + self.X_raw[self.ingredients_column] = self.X_raw[self.ingredients_column].apply( + literal_eval + ) + # Loop in the list of dicts + for ingre_list, index in zip( + self.X_raw[self.ingredients_column], self.X_raw.index + ): + # Loop in the dicts for ingre_dict in ingre_list: - #get text - val = ingre_dict['text'] - #Clean text - val_clean = val.replace('_','').replace('-','').strip('').lower() - #Check if the ingredient is in features + # get text + val = ingre_dict["text"] + # Clean text + val_clean = val.replace("_", "").replace("-", "").strip("").lower() + # Check if the ingredient is in features if val_clean in self.cols_for_ings: - #Try to append with percent_estimate + # Try to append with percent_estimate try: - self.X_raw.loc[index,'ing_'+val_clean] = ingre_dict['percent_estimate'] - #If not working, try to append with percent_min + self.X_raw.loc[index, "ing_" + val_clean] = ingre_dict[ + "percent_estimate" + ] + # If not working, try to append with percent_min except: try: - self.X_raw.loc[index,'ing_'+val_clean] = ingre_dict['percent_min'] - #If not working, append 1 + self.X_raw.loc[index, "ing_" + val_clean] = ingre_dict[ + "percent_min" + ] + # If not working, append 1 except: - self.X_raw.loc[index,'ing_'+val_clean] = 1 - #Finally drop original ingredients col + self.X_raw.loc[index, "ing_" + val_clean] = 1 + # Finally drop original ingredients col self.X_raw.drop(columns=[self.ingredients_column], inplace=True) def process_names(self): - #Loop in text column + # Loop in text column for text, index in zip(self.X_raw[self.text_column], self.X_raw.index): row_text = text.lower() - #Loop in words features selected + # Loop in words features selected for word in self.cols_for_text: - #If the word is in text, append 1 + # If the word 
is in text, append 1 if word in row_text: - self.X_raw.loc[index,word] = 1 - #Finally drop original text col + self.X_raw.loc[index, word] = 1 + # Finally drop original text col self.X_raw = self.X_raw.drop(columns=self.text_column) - + def filter_preds(self, y_probas, tresholds): - """ - Filter pred with confidence treshold. + """ + Filter pred with confidence treshold. Fill unkeeped pred with 'y_unkown'. """ preds = [] y_unknown_code = len(tresholds) for p in y_probas: result = np.argwhere(p > tresholds) - if result.size > 0 : preds.append(int(result[0])) - else : preds.append(y_unknown_code) + if result.size > 0: + preds.append(int(result[0])) + else: + preds.append(y_unknown_code) return np.array(preds) def get_confidences(self, y_pred, y_probas, unk_code=9): - """ Get confidence -range (0,1)- for the label predicted""" + """Get confidence -range (0,1)- for the label predicted""" probas = [] for index, label in enumerate(y_pred): - if label < unk_code : - probas.append(y_probas[index,label]) - else : probas.append(0.0) - confidences = np.array(probas).reshape(-1,1) + if label < unk_code: + probas.append(y_probas[index, label]) + else: + probas.append(0.0) + confidences = np.array(probas).reshape(-1, 1) return confidences - def decode_labels(self, preds, label_dict): + def decode_labels(self, preds, label_dict): preds_decoded = [] - for pred in preds: preds_decoded.append(label_dict[str(pred)]) + for pred in preds: + preds_decoded.append(label_dict[str(pred)]) return np.array(preds_decoded) - def predict(self, X, decode_labels=True, pred_format='pd', get_confidence=True, preprocess=True): + def predict( + self, + X, + decode_labels=True, + pred_format="pd", + get_confidence=True, + preprocess=True, + ): """ Get predictions of G1 and G2 with XGBoost models, with optional confidences levels. - + Parameters ---------- - X (pd.DataFrame or np.array) : Features inputs. + X (pd.DataFrame or np.array) : Features inputs. Should have one structured column "ingredients" and a column "product_name" Can be already preprocessed (shape N samples, 938 features, set preprocess=True) decode_labels(Bool, default = True) : Return original labels (string) instead of encoded (int). - pred_format(default = 'pd'): + pred_format(default = 'pd'): 'pd' : Return a dataframe with predictions for G1, G2 and confidences if get_confidence set to True 'np' : return a numpy array with predictions and confidences if get_confidence set to True. get_confidence(Bool, default = True) : Get level of confidence for each prediction. 
@@ -168,56 +198,64 @@ def predict(self, X, decode_labels=True, pred_format='pd', get_confidence=True, - if get confidences set to True: - y_conf_G1 : level of confidence for every G1 prediction - y_conf_G2 : level of confidence for every G2 prediction - + """ - if preprocess: self.process_X(X) - else: self.X = X + if preprocess: + self.process_X(X) + else: + self.X = X - with open(r'files\tresholds_G2.json') as json_file: + with open(r"files\tresholds_G2.json") as json_file: tresholds_G2 = json.load(json_file) - + tresholds_G1 = np.array([0.6, 0.5, 0.5, 0.6, 0.5, 0.5, 0.5, 0.5, 0.5]) y_probas_G1 = self.model_G1.predict_proba(self.X) y_preds_G1 = self.filter_preds(y_probas_G1, tresholds=tresholds_G1) y_conf_G1 = self.get_confidences(y_preds_G1, y_probas_G1, unk_code=9) - X_G2 = np.append(self.X, y_preds_G1.reshape(-1,1), axis=1) + X_G2 = np.append(self.X, y_preds_G1.reshape(-1, 1), axis=1) - tresholds_G2 =np.array(list(tresholds_G2.values())) + tresholds_G2 = np.array(list(tresholds_G2.values())) y_probas_G2 = self.model_G2.predict_proba(X_G2) y_preds_G2 = self.filter_preds(y_probas_G2, tresholds=tresholds_G2) y_conf_G2 = self.get_confidences(y_preds_G2, y_probas_G2, unk_code=38) - + if decode_labels: - with open(r'files\labels_G1_code_reference.json') as json_file: + with open(r"files\labels_G1_code_reference.json") as json_file: self.le_G1 = json.load(json_file) - with open(r'files\labels_G2_code_reference.json') as json_file: + with open(r"files\labels_G2_code_reference.json") as json_file: self.le_G2 = json.load(json_file) y_preds_G1 = self.decode_labels(y_preds_G1, self.le_G1) y_preds_G2 = self.decode_labels(y_preds_G2, self.le_G2) - if pred_format == 'np': - if get_confidence : - predictions_G1 = np.append(y_preds_G1.reshape(-1,1), y_conf_G1, axis=1) - predictions_G2 = np.append(y_preds_G2.reshape(-1,1), y_conf_G2, axis=1) + if pred_format == "np": + if get_confidence: + predictions_G1 = np.append(y_preds_G1.reshape(-1, 1), y_conf_G1, axis=1) + predictions_G2 = np.append(y_preds_G2.reshape(-1, 1), y_conf_G2, axis=1) predictions = np.append(predictions_G1, predictions_G2, axis=1) else: - predictions = np.append(y_preds_G1.reshape(-1,1), y_preds_G2.reshape(-1,1), axis=1) - - elif pred_format == 'pd': - if get_confidence: - predictions = pd.DataFrame({ - 'y_pred_G1':y_preds_G1, - 'y_conf_G1':y_conf_G1.round(2).flatten(), - 'y_pred_G2':y_preds_G2, - 'y_conf_G2':y_conf_G2.round(2).flatten(), - }) + predictions = np.append( + y_preds_G1.reshape(-1, 1), y_preds_G2.reshape(-1, 1), axis=1 + ) + + elif pred_format == "pd": + if get_confidence: + predictions = pd.DataFrame( + { + "y_pred_G1": y_preds_G1, + "y_conf_G1": y_conf_G1.round(2).flatten(), + "y_pred_G2": y_preds_G2, + "y_conf_G2": y_conf_G2.round(2).flatten(), + } + ) else: - predictions = pd.DataFrame({ - 'y_pred_G1':y_preds_G1, - 'y_pred_G2':y_preds_G2, - }) - + predictions = pd.DataFrame( + { + "y_pred_G1": y_preds_G1, + "y_pred_G2": y_preds_G2, + } + ) + self.predictions = predictions return predictions diff --git a/ai-emlyon/scripts/lucain_script_taxonomy.py b/ai-emlyon/scripts/lucain_script_taxonomy.py index f145e88f..99f9aff6 100644 --- a/ai-emlyon/scripts/lucain_script_taxonomy.py +++ b/ai-emlyon/scripts/lucain_script_taxonomy.py @@ -1,8 +1,8 @@ -#----------------------------LUCAIN SCRIPT FOR TAXONOMY & PNNS---------------------------- +# ----------------------------LUCAIN SCRIPT FOR TAXONOMY & PNNS---------------------------- import json -#if not installed : pip install python-Levenshtein +# if not installed : pip install 
python-Levenshtein from Levenshtein import distance as levenshtein_distance from robotoff.taxonomy import get_taxonomy @@ -97,4 +97,4 @@ def _clean_str(s): # export to JSON with open("taxonomy_pnns.json", "w") as f: - json.dump(output, f, indent=2, sort_keys=True) \ No newline at end of file + json.dump(output, f, indent=2, sort_keys=True) diff --git a/ai-emlyon/scripts/merging_script.py b/ai-emlyon/scripts/merging_script.py index 4942bd30..4f622dc8 100644 --- a/ai-emlyon/scripts/merging_script.py +++ b/ai-emlyon/scripts/merging_script.py @@ -1,100 +1,174 @@ -#__________________________Merging Predicitions with OFF Database__________________________ +# __________________________Merging Predicitions with OFF Database__________________________ -#---------------Libraries--------------- +# ---------------Libraries--------------- import pandas as pd from pandas.io.json import json_normalize -#---------------Reading files--------------- - -#Read OFF full database -full_database = pd.read_csv(r'C:\Users\Antoine\Coding Bootcamp\machine learning\ -Open Food Facts\data\en.openfoodfacts.org.products.csv', low_memory = False, -sep='\t',error_bad_lines=False) - -#Read all robotoff predictions files -pred1 = pd.read_json(r'C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food Facts\robotoff_predictions\dump_page1_by_100000.json') -pred2 = pd.read_json(r'C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food Facts\robotoff_predictions\dump_page2_by_100000.json') -pred3 = pd.read_json(r'C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food Facts\robotoff_predictions\dump_page3_by_100000.json') -pred4 = pd.read_json(r'C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food Facts\robotoff_predictions\dump_page4_by_100000.json') -pred5 = pd.read_json(r'C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food Facts\robotoff_predictions\dump_page5_by_100000.json') -pred6 = pd.read_json(r'C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food Facts\robotoff_predictions\dump_page6_by_100000.json') -pred7 = pd.read_json(r'C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food Facts\robotoff_predictions\dump_page7_by_100000.json') -pred8 = pd.read_json(r'C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food Facts\robotoff_predictions\dump_page8_by_100000.json') - -#---------------Concat predictions files--------------- - -#Save files into a list to concat them +# ---------------Reading files--------------- + +# Read OFF full database +full_database = pd.read_csv( + r"C:\Users\Antoine\Coding Bootcamp\machine learning\ +Open Food Facts\data\en.openfoodfacts.org.products.csv", + low_memory=False, + sep="\t", + error_bad_lines=False, +) + +# Read all robotoff predictions files +pred1 = pd.read_json( + r"C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food Facts\robotoff_predictions\dump_page1_by_100000.json" +) +pred2 = pd.read_json( + r"C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food Facts\robotoff_predictions\dump_page2_by_100000.json" +) +pred3 = pd.read_json( + r"C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food Facts\robotoff_predictions\dump_page3_by_100000.json" +) +pred4 = pd.read_json( + r"C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food Facts\robotoff_predictions\dump_page4_by_100000.json" +) +pred5 = pd.read_json( + r"C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food Facts\robotoff_predictions\dump_page5_by_100000.json" +) +pred6 = pd.read_json( + r"C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food 
Facts\robotoff_predictions\dump_page6_by_100000.json" +) +pred7 = pd.read_json( + r"C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food Facts\robotoff_predictions\dump_page7_by_100000.json" +) +pred8 = pd.read_json( + r"C:\Users\Antoine\Coding Bootcamp\machine learning\Open Food Facts\robotoff_predictions\dump_page8_by_100000.json" +) + +# ---------------Concat predictions files--------------- + +# Save files into a list to concat them pred_list = [pred1, pred2, pred3, pred4, pred5, pred6, pred7, pred8] -#Concat files into one df and normalize with json_normalize (for insights column) +# Concat files into one df and normalize with json_normalize (for insights column) raw_predictions = pd.concat(pred_list) -predictions = json_normalize(raw_predictions['insights']) - -#Drop unrelevant columns -cols_pred_to_drop = ['id', 'type', 'timestamp', 'latent', 'process_after', 'value','source_image', -'automatic_processing', 'server_domain', 'server_type','unique_scans_n', 'predictor', 'robotoff_reserved_barcode'] +predictions = json_normalize(raw_predictions["insights"]) + +# Drop unrelevant columns +cols_pred_to_drop = [ + "id", + "type", + "timestamp", + "latent", + "process_after", + "value", + "source_image", + "automatic_processing", + "server_domain", + "server_type", + "unique_scans_n", + "predictor", + "robotoff_reserved_barcode", +] predictions.drop(cols_pred_to_drop, axis=1, inplace=True) -#Add a prefix "robotoff" to the columns (so we can later identify predictions columns into the merged dataframe) -predictions.rename(columns={colname: 'robotoff_' + colname for colname in predictions.columns}, inplace=True) -predictions.rename(columns={'robotoff_barcode' : 'code'}, inplace=True) +# Add a prefix "robotoff" to the columns (so we can later identify predictions columns into the merged dataframe) +predictions.rename( + columns={colname: "robotoff_" + colname for colname in predictions.columns}, + inplace=True, +) +predictions.rename(columns={"robotoff_barcode": "code"}, inplace=True) -#---------------Merging files--------------- +# ---------------Merging files--------------- -#Merge predictions with full database -df_with_predictions = pd.merge(full_database, predictions, on=['code'], how='right') +# Merge predictions with full database +df_with_predictions = pd.merge(full_database, predictions, on=["code"], how="right") -#---------------Clean columns--------------- +# ---------------Clean columns--------------- -#Drop columns with more than 50% nans +# Drop columns with more than 50% nans cols = df_with_predictions.columns.to_list() threshold = int(len(df_with_predictions) * 0.5) -cols_removed = [] #cols removed are saved into a list +cols_removed = [] # cols removed are saved into a list -#We keep columns with category info, even if almost empty -for col in cols : - if (not 'catego' in col) and (df_with_predictions[col].count() < threshold) : +# We keep columns with category info, even if almost empty +for col in cols: + if (not "catego" in col) and (df_with_predictions[col].count() < threshold): df_with_predictions.drop([col], axis=1, inplace=True) cols_removed.append(col) -#Drop some other columns considered unrelevant for the benchmark -cols_to_drop = ['last_modified_t', 'last_modified_datetime', 'image_url', 'image_small_url', - 'image_nutrition_url', 'image_nutrition_small_url', 'energy-kcal_100g', - 'energy_100g', 'fat_100g', 'saturated-fat_100g', 'carbohydrates_100g', - 'sugars_100g', 'proteins_100g', 'salt_100g', 'sodium_100g', 'created_t', 'states_en'] +# Drop some other 
columns considered unrelevant for the benchmark +cols_to_drop = [ + "last_modified_t", + "last_modified_datetime", + "image_url", + "image_small_url", + "image_nutrition_url", + "image_nutrition_small_url", + "energy-kcal_100g", + "energy_100g", + "fat_100g", + "saturated-fat_100g", + "carbohydrates_100g", + "sugars_100g", + "proteins_100g", + "salt_100g", + "sodium_100g", + "created_t", + "states_en", +] cols_removed.append(cols_to_drop) -df_with_predictions.drop(columns=cols_to_drop, inplace=True, errors='ignore') - -#---------------Split dataframes--------------- - -#Create to dataframes : one with categories filled, one without -df_with_cats = df_with_predictions.dropna(subset=['categories', 'categories_tags', 'categories_en', 'main_category', 'main_category_en']) +df_with_predictions.drop(columns=cols_to_drop, inplace=True, errors="ignore") + +# ---------------Split dataframes--------------- + +# Create to dataframes : one with categories filled, one without +df_with_cats = df_with_predictions.dropna( + subset=[ + "categories", + "categories_tags", + "categories_en", + "main_category", + "main_category_en", + ] +) df_without_cats = df_with_predictions.drop(df_with_cats.index) -#---------------Note--------------- -#Predictions are for a category tag -#Some of the products have multiple tags in the OFF database -#Therefore it is difficult to compare them with predictions -#---------------------------------- - -#Create a col with the number of categories tags for every product -df_with_cats['nb_cats_tags'] = df_with_cats.apply(lambda x: len(x['categories_tags'].split(',')), axis=1) -#We split rows with multiple tags and create a new col "single tag" -single_category_tag = df_with_cats['categories_tags'].str.split(',').apply(pd.Series, 1).stack() +# ---------------Note--------------- +# Predictions are for a category tag +# Some of the products have multiple tags in the OFF database +# Therefore it is difficult to compare them with predictions +# ---------------------------------- + +# Create a col with the number of categories tags for every product +df_with_cats["nb_cats_tags"] = df_with_cats.apply( + lambda x: len(x["categories_tags"].split(",")), axis=1 +) +# We split rows with multiple tags and create a new col "single tag" +single_category_tag = ( + df_with_cats["categories_tags"].str.split(",").apply(pd.Series, 1).stack() +) single_category_tag.index = single_category_tag.index.droplevel(-1) -single_category_tag.name = 'single_category_tag' +single_category_tag.name = "single_category_tag" df_with_cats = df_with_cats.join(single_category_tag) -#---------------Export files to csv--------------- - -#Full merged dataframe -df_with_predictions.to_csv(r'C:\Users\Antoine\Coding Bootcamp\machine learning\ -Open Food Facts\data\merged_predictions_off_full.csv', index=False, header=True) -#Only products with filled categories -df_with_cats.to_csv(r'C:\Users\Antoine\Coding Bootcamp\machine learning\ -Open Food Facts\data\merged_predictions_off_with_categories.csv', index=False, header=True) -#Only products with categories columns empty -df_without_cats.to_csv(r'C:\Users\Antoine\Coding Bootcamp\machine learning\ -Open Food Facts\data\merged_predictions_off_without_categories.csv', index=False, header=True) - +# ---------------Export files to csv--------------- + +# Full merged dataframe +df_with_predictions.to_csv( + r"C:\Users\Antoine\Coding Bootcamp\machine learning\ +Open Food Facts\data\merged_predictions_off_full.csv", + index=False, + header=True, +) +# Only products with filled 
categories +df_with_cats.to_csv( + r"C:\Users\Antoine\Coding Bootcamp\machine learning\ +Open Food Facts\data\merged_predictions_off_with_categories.csv", + index=False, + header=True, +) +# Only products with categories columns empty +df_without_cats.to_csv( + r"C:\Users\Antoine\Coding Bootcamp\machine learning\ +Open Food Facts\data\merged_predictions_off_without_categories.csv", + index=False, + header=True, +) diff --git a/data4good_logo_detection/setup.py b/data4good_logo_detection/setup.py index 56ec183c..d1618024 100644 --- a/data4good_logo_detection/setup.py +++ b/data4good_logo_detection/setup.py @@ -1,3 +1,3 @@ from setuptools import setup, find_packages -setup(name="logo_detection", version='1.0', packages=find_packages()) \ No newline at end of file +setup(name="logo_detection", version="1.0", packages=find_packages()) diff --git a/data4good_logo_detection/src/data/explore_data.py b/data4good_logo_detection/src/data/explore_data.py index 4dc8325b..020ce629 100644 --- a/data4good_logo_detection/src/data/explore_data.py +++ b/data4good_logo_detection/src/data/explore_data.py @@ -1,7 +1,7 @@ def extract_values(obj, key): """Pull all values of specified key from nested JSON.""" arr = [] - + def extract(obj, arr, key): """Recursively search for values of key in JSON tree.""" if isinstance(obj, dict): @@ -14,6 +14,6 @@ def extract(obj, arr, key): for item in obj: extract(item, arr, key) return arr - + results = extract(obj, arr, key) return results diff --git a/ingredient_extraction/dataset-generation/clean_dataset.py b/ingredient_extraction/dataset-generation/clean_dataset.py index 4c21dedf..4b539cdd 100644 --- a/ingredient_extraction/dataset-generation/clean_dataset.py +++ b/ingredient_extraction/dataset-generation/clean_dataset.py @@ -118,9 +118,7 @@ def annotate(item: dict, existing_annotation: Optional[dict] = None): f"action='{existing_annotation['action']}', " f"updated_offsets={existing_annotation['updated_offsets']}" ) - marked_text = generate_highlighted_text( - item["text"], [list(x) for x in offsets] - ) + marked_text = generate_highlighted_text(item["text"], [list(x) for x in offsets]) marked_text_highlighted = marked_text.replace("", "[red]").replace( "", "[/red]" ) diff --git a/ingredient_extraction/model-analysis/evaluate_model.py b/ingredient_extraction/model-analysis/evaluate_model.py index e26b6fcb..35ca85aa 100644 --- a/ingredient_extraction/model-analysis/evaluate_model.py +++ b/ingredient_extraction/model-analysis/evaluate_model.py @@ -22,7 +22,9 @@ predictions = {} -with gzip.open(prediction_dir / f"{SPLIT_NAME}_predictions_agg_first.jsonl.gz", "rt") as f: +with gzip.open( + prediction_dir / f"{SPLIT_NAME}_predictions_agg_first.jsonl.gz", "rt" +) as f: for line in f: item = json.loads(line) id_ = item["meta"]["id"] @@ -48,7 +50,9 @@ gold_markup = generate_highlighted_text( gold_sample["text"], gold_offsets, html_escape=True, mark_token="mark" ) - predicted_offsets = [[entity["start"], entity["end"]] for entity in predict_sample["entities"]] + predicted_offsets = [ + [entity["start"], entity["end"]] for entity in predict_sample["entities"] + ] predict_markup = generate_highlighted_text( predict_sample["text"], @@ -67,7 +71,12 @@ if is_manually_annotated else automatically_annotated_html_list ) - predicted_offset_html = "
".join([f"\"{predict_sample['text'][start:end]}\" [{start}:{end}]" for start, end in predicted_offsets]) + predicted_offset_html = "
".join( + [ + f"\"{predict_sample['text'][start:end]}\" [{start}:{end}]" + for start, end in predicted_offsets + ] + ) html_list.append( f"

ID: {key} Gold: {gold_markup} Predicted: {predicted_offset_html} {predict_markup}

" ) diff --git a/ingredient_extraction/model-analysis/test_prediction.py b/ingredient_extraction/model-analysis/test_prediction.py index 6b24febd..eb3b6e76 100644 --- a/ingredient_extraction/model-analysis/test_prediction.py +++ b/ingredient_extraction/model-analysis/test_prediction.py @@ -1,5 +1,10 @@ from datasets import load_dataset -from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline, TokenClassificationPipeline +from transformers import ( + AutoModelForTokenClassification, + AutoTokenizer, + pipeline, + TokenClassificationPipeline, +) import wandb ARTIFACT_NAME = ( diff --git a/ingredient_extraction/train/run_pred_colab.py b/ingredient_extraction/train/run_pred_colab.py index fb896fb9..ce6cf0fd 100644 --- a/ingredient_extraction/train/run_pred_colab.py +++ b/ingredient_extraction/train/run_pred_colab.py @@ -29,22 +29,22 @@ ) for split_name in ("test", "train"): - split_ds = dataset[split_name] - texts = split_ds["text"] - aggregated_outputs = classifier(texts, batch_size=16) - full_output = ( - { - "text": split_ds[i]["text"], - "meta": split_ds[i]["meta"], - "entities": entities, - } - for i, entities in enumerate(aggregated_outputs) - ) - prediction_file_path = Path(f"{split_name}_predictions_agg.jsonl.gz") - with gzip.open(prediction_file_path, "wb") as f: - f.write( - b"\n".join( - orjson.dumps(item, option=orjson.OPT_SERIALIZE_NUMPY) - for item in full_output - ) + split_ds = dataset[split_name] + texts = split_ds["text"] + aggregated_outputs = classifier(texts, batch_size=16) + full_output = ( + { + "text": split_ds[i]["text"], + "meta": split_ds[i]["meta"], + "entities": entities, + } + for i, entities in enumerate(aggregated_outputs) + ) + prediction_file_path = Path(f"{split_name}_predictions_agg.jsonl.gz") + with gzip.open(prediction_file_path, "wb") as f: + f.write( + b"\n".join( + orjson.dumps(item, option=orjson.OPT_SERIALIZE_NUMPY) + for item in full_output ) + ) diff --git a/logo-ann/benchmarks/ANN_benchmark/elasticsearch_benchmark.py b/logo-ann/benchmarks/ANN_benchmark/elasticsearch_benchmark.py index 0fb7b130..8409b1b9 100644 --- a/logo-ann/benchmarks/ANN_benchmark/elasticsearch_benchmark.py +++ b/logo-ann/benchmarks/ANN_benchmark/elasticsearch_benchmark.py @@ -16,307 +16,352 @@ from recall_computation import compute_recall -def create_index(es: Elasticsearch, index_name: str, d:int, ef_construction: int, m: int): - """ - Create an index for the Elasticsearch server and set the mappings for the embedding field. - - Parameters: - es (Elasticsearch): an instance of Elasticsearch class - index_name (str): the name of the index to create - d (int): the size of the embeddings - ef_construction (int): the maximum number of elements to consider during the construction of the index - m (int): the number of connections that each node in the index has to its neighbors - - Returns: - float: the time taken to create the index - """ - - start_time = time.monotonic() - - mappings = { - "properties": { - "external_id": {"type": "integer"}, - "embedding": { - "type": "dense_vector", - "dims": d, - "index": True, - "similarity": "dot_product", +def create_index( + es: Elasticsearch, index_name: str, d: int, ef_construction: int, m: int +): + """ + Create an index for the Elasticsearch server and set the mappings for the embedding field. 
+ + Parameters: + es (Elasticsearch): an instance of Elasticsearch class + index_name (str): the name of the index to create + d (int): the size of the embeddings + ef_construction (int): the maximum number of elements to consider during the construction of the index + m (int): the number of connections that each node in the index has to its neighbors + + Returns: + float: the time taken to create the index + """ + + start_time = time.monotonic() + + mappings = { + "properties": { + "external_id": {"type": "integer"}, + "embedding": { + "type": "dense_vector", + "dims": d, + "index": True, + "similarity": "dot_product", "index_options": { - "type": "hnsw", - "m": m, - "ef_construction": ef_construction, - }, + "type": "hnsw", + "m": m, + "ef_construction": ef_construction, }, - }, - } - - try : - es.indices.delete(index=index_name) - except : - print("No index existing yet. Let's create one !") - - es.indices.create(index=index_name, mappings=mappings) + }, + }, + } - return time.monotonic() - start_time + try: + es.indices.delete(index=index_name) + except: + print("No index existing yet. Let's create one !") + es.indices.create(index=index_name, mappings=mappings) -def build_index(es: Elasticsearch, index_name: str, embeddings_path: pathlib.Path, d: int, batch_size : int, n_vec : int): - """ - Builds an Elasticsearch index from a set of embeddings + return time.monotonic() - start_time - Parameters: - es (Elasticsearch): Elasticsearch object to build index on - index_name (str): name of index to be created - embeddings_path (pathlib.Path): path to the file containing embeddings to be added to the index - d (int): dimension of each embedding - batch_size (int): number of embeddings to be added to the index at a time - n_vec (int): total number of embeddings to be added to the index - Returns: - building_time (int): time taken to build the index - """ +def build_index( + es: Elasticsearch, + index_name: str, + embeddings_path: pathlib.Path, + d: int, + batch_size: int, + n_vec: int, +): + """ + Builds an Elasticsearch index from a set of embeddings - building_time = 0 + Parameters: + es (Elasticsearch): Elasticsearch object to build index on + index_name (str): name of index to be created + embeddings_path (pathlib.Path): path to the file containing embeddings to be added to the index + d (int): dimension of each embedding + batch_size (int): number of embeddings to be added to the index at a time + n_vec (int): total number of embeddings to be added to the index - data_gen = get_embedding(embeddings_path, batch_size, n_vec) + Returns: + building_time (int): time taken to build the index + """ - unique_id = 0 + building_time = 0 - for batch in tqdm.tqdm(data_gen): - (embeddings_batch, external_id_batch) = batch - - for i in range(len(embeddings_batch)): - unique_id +=1 - elapsed = time.monotonic() - vector = embeddings_batch[i].reshape(1,d) - vector = normalize(vector) - vector = vector.reshape(d) + data_gen = get_embedding(embeddings_path, batch_size, n_vec) - external_id = int(external_id_batch[i]) + unique_id = 0 - doc = { - "embedding": vector, - } - es.index(index=index_name, id=external_id, document=doc) - building_time = building_time + time.monotonic() - elapsed + for batch in tqdm.tqdm(data_gen): + (embeddings_batch, external_id_batch) = batch - return building_time + for i in range(len(embeddings_batch)): + unique_id += 1 + elapsed = time.monotonic() + vector = embeddings_batch[i].reshape(1, d) + vector = normalize(vector) + vector = vector.reshape(d) + external_id = 
int(external_id_batch[i]) -def search_index( - es: Elasticsearch, - index_name: str, - embeddings_path: pathlib.Path, - queries: int, - batch_size: int, - K: np.array, - num_candidates: int, - ): - """ - Searches the Elasticsearch index and returns the nearest neighbours of a set of query vectors - - Parameters: - es (Elasticsearch): Elasticsearch object to search on - index_name (str): name of the index to search on - embeddings_path (pathlib.Path): path to the file containing query vectors - queries (int): number of query vectors - batch_size (int): number of query vectors to search at a time - K (np.array): array of top K nearest neighbours to retrieve - num_candidates (int): number of candidates to be considered during search - - Returns: - nearest_neighbours (dict): dictionary containing the nearest neighbours for each query vector - search_time (float): average time taken to search for each query vector - """ - - search_time = [] - - nearest_neighbours = {} - - data_gen = get_embedding(embeddings_path, batch_size, queries) - - for batch in tqdm.tqdm(data_gen): - (embeddings_batch, external_id_batch) = batch - - for i in range(len(embeddings_batch)): - time_before = time.monotonic() - vector = embeddings_batch[i].reshape(1,d) - vector = normalize(vector) - vector = vector.reshape(d) - - request = { - "knn": { - "field": "embedding", - "query_vector": vector, - "k": int(K.max())+1, - "num_candidates": num_candidates - }, - } - res = es.knn_search(index = index_name, body=request) - search_time.append(time.monotonic()-time_before) + doc = { + "embedding": vector, + } + es.index(index=index_name, id=external_id, document=doc) + building_time = building_time + time.monotonic() - elapsed - query_id = str(external_id_batch[i]) + return building_time - nearest_neighbours[query_id] = {} - hits = res['hits']['hits'] +def search_index( + es: Elasticsearch, + index_name: str, + embeddings_path: pathlib.Path, + queries: int, + batch_size: int, + K: np.array, + num_candidates: int, +): + """ + Searches the Elasticsearch index and returns the nearest neighbours of a set of query vectors + + Parameters: + es (Elasticsearch): Elasticsearch object to search on + index_name (str): name of the index to search on + embeddings_path (pathlib.Path): path to the file containing query vectors + queries (int): number of query vectors + batch_size (int): number of query vectors to search at a time + K (np.array): array of top K nearest neighbours to retrieve + num_candidates (int): number of candidates to be considered during search + + Returns: + nearest_neighbours (dict): dictionary containing the nearest neighbours for each query vector + search_time (float): average time taken to search for each query vector + """ + + search_time = [] + + nearest_neighbours = {} + + data_gen = get_embedding(embeddings_path, batch_size, queries) + + for batch in tqdm.tqdm(data_gen): + (embeddings_batch, external_id_batch) = batch + + for i in range(len(embeddings_batch)): + time_before = time.monotonic() + vector = embeddings_batch[i].reshape(1, d) + vector = normalize(vector) + vector = vector.reshape(d) + + request = { + "knn": { + "field": "embedding", + "query_vector": vector, + "k": int(K.max()) + 1, + "num_candidates": num_candidates, + }, + } + res = es.knn_search(index=index_name, body=request) + search_time.append(time.monotonic() - time_before) - nearest_neighbours[query_id]["ids"] = np.array([hits[i]['_source']['external_id'] for i in range(len(hits))]) - nearest_neighbours[query_id]["score"] = 
np.array([hits[i]['_score'] for i in range(len(hits))]) + query_id = str(external_id_batch[i]) - - return nearest_neighbours, np.mean(search_time) + nearest_neighbours[query_id] = {} + hits = res["hits"]["hits"] -def main( - es: Elasticsearch, - index_name: str, - embeddings_path: pathlib.Path, - ground_truth_path: pathlib.Path, - save_dir: str, - batch_size: int, - d: int, - queries: int, - index_size: int, - K: np.array, - NUM_list: np.array, - EF_CONSTRUCTION_list: np.array, - M_list: np.array, - ): - """ - Builds and searches an Elasticsearch index to compute the performances of ANN indexes - - Parameters: - es (Elasticsearch): Elasticsearch object to build and search on - index_name (str): name of the index to be created and searched on - embeddings_path (pathlib.Path): path to the file containing embeddings to be added to the index and query vectors to search - ground_truth_path (pathlib.Path): path to the file containing ground truth nearest neighbours for evaluation - save_dir (str): directory to save results - batch_size (int): number of embeddings to be added to the index and query vectors to search at a time - d (int): dimension of each embedding - queries (int): number of query vectors - index_size (int): number of embeddings to be added to the index - K (np.array): array of top K nearest neighbours to retrieve - NUM_list (np.array): array of num_candidates values to be considered during search - EF_CONSTRUCTION_list (np.array): array of ef_construction values to be considered during search - M_list (np.array): array of M values to be considered during search - - Returns: - performances (dict): dictionary containing the search performances for various parameter combinations - """ - - performances = {} - - performances["index_size"] = index_size - performances["queries"] = queries - - try: - os.mkdir(save_dir) - print("Save directory created !") - except: - print("Save directory already existing !") - - ground_truth = None - if ground_truth_path.isfile(): - ground_truth = load_data(KNN_file) - - performances["hnsw"] = {} - index_number = 0 - - for num_candidates in NUM_list: - num_candidates = int(num_candidates) - for m in M_list: - m = int(m) - for ef_construction in EF_CONSTRUCTION_list: - ef_construction = int(ef_construction) - - - print(f"****** {index_number}th HNSW index ******") - - perf_file = save_dir + "/hnsw_perf_ncandidates_" + str(num_candidates) + "_m_" + str(m) + "_efconstruction_" + str(ef_construction) + ".json" - KNN_file = save_dir + "/hnsw_ANN_ncandidates_" + str(num_candidates) + "_m_" + str(m) + "_efconstruction_" + str(ef_construction) + ".json" - - if pathlib.Path(perf_file).is_file() and pathlib.Path(KNN_file).is_file(): - hnsw_perf = load_data(perf_file) - hnsw_nearest_neighbours = load_data(KNN_file) - - else : - hnsw_perf = {} - - hnsw_perf["num_candidates"] = num_candidates - - print(f"Creation of the {index_number}th hnsw index") - - hnsw_perf["creation_time"] = create_index(es, index_name, d, ef_construction, m) - - ram_available_before = psutil.virtual_memory().available - - hnsw_perf["building_time"] = build_index(es, index_name, embeddings_path, d, batch_size, index_size) - - hnsw_perf["ram_used"] = (ram_available_before-psutil.virtual_memory().available)/(1024.0**3) - - - print(f"Search in the {index_number}th hsnw index") - - hnsw_nearest_neighbours, hnsw_perf["search_time"] = search_index(es, index_name, embeddings_path, queries, batch_size, K, num_candidates) - - hnsw_perf["metrics"] = compute_recall( - ground_truth, - 
hnsw_nearest_neighbours, - index_size, - queries, - K, - num_candidates, + nearest_neighbours[query_id]["ids"] = np.array( + [hits[i]["_source"]["external_id"] for i in range(len(hits))] + ) + nearest_neighbours[query_id]["score"] = np.array( + [hits[i]["_score"] for i in range(len(hits))] ) - save_data(perf_file, performances["hnsw"][index_number]) - save_data(KNN_file, hnsw_nearest_neighbours) - index_number = index_number + 1 - - - return performances - - -if __name__ == '__main__': - - warnings.filterwarnings("ignore") - - es = Elasticsearch(hosts='http://elastic:elastic_pswd@localhost:9200', verify_certs=False) - es.info().body - - index_name = "logos" - - ground_truth_path = pathlib.Path("faiss_saves_index_1000_queries_10/exact_KNN.json") - - embeddings_path=pathlib.Path("logos_embeddings_512.hdf5") - batch_size = 512 - d = 512 # dimension of the embeddings of the embeddings_path file - index_size = 1000 - queries = 10 - K = np.array([1,5,10,50,100]) - - NUM_array = np.array([110]) - EF_CONSTRUCTION_list = np.array([100]) - M_list = np.array([8]) - + return nearest_neighbours, np.mean(search_time) - save_dir = "elastic_index_" + str(index_size) + "_queries_" + str(queries) - complete_perf = main( - es, - index_name, - embeddings_path, - ground_truth_path, - save_dir, - batch_size, - d, - queries, - index_size, - K, - NUM_array, - EF_CONSTRUCTION_list, - M_list, +def main( + es: Elasticsearch, + index_name: str, + embeddings_path: pathlib.Path, + ground_truth_path: pathlib.Path, + save_dir: str, + batch_size: int, + d: int, + queries: int, + index_size: int, + K: np.array, + NUM_list: np.array, + EF_CONSTRUCTION_list: np.array, + M_list: np.array, +): + """ + Builds and searches an Elasticsearch index to compute the performances of ANN indexes + + Parameters: + es (Elasticsearch): Elasticsearch object to build and search on + index_name (str): name of the index to be created and searched on + embeddings_path (pathlib.Path): path to the file containing embeddings to be added to the index and query vectors to search + ground_truth_path (pathlib.Path): path to the file containing ground truth nearest neighbours for evaluation + save_dir (str): directory to save results + batch_size (int): number of embeddings to be added to the index and query vectors to search at a time + d (int): dimension of each embedding + queries (int): number of query vectors + index_size (int): number of embeddings to be added to the index + K (np.array): array of top K nearest neighbours to retrieve + NUM_list (np.array): array of num_candidates values to be considered during search + EF_CONSTRUCTION_list (np.array): array of ef_construction values to be considered during search + M_list (np.array): array of M values to be considered during search + + Returns: + performances (dict): dictionary containing the search performances for various parameter combinations + """ + + performances = {} + + performances["index_size"] = index_size + performances["queries"] = queries + + try: + os.mkdir(save_dir) + print("Save directory created !") + except: + print("Save directory already existing !") + + ground_truth = None + if ground_truth_path.isfile(): + ground_truth = load_data(KNN_file) + + performances["hnsw"] = {} + index_number = 0 + + for num_candidates in NUM_list: + num_candidates = int(num_candidates) + for m in M_list: + m = int(m) + for ef_construction in EF_CONSTRUCTION_list: + ef_construction = int(ef_construction) + + print(f"****** {index_number}th HNSW index ******") + + perf_file = ( + save_dir + + 
"/hnsw_perf_ncandidates_" + + str(num_candidates) + + "_m_" + + str(m) + + "_efconstruction_" + + str(ef_construction) + + ".json" + ) + KNN_file = ( + save_dir + + "/hnsw_ANN_ncandidates_" + + str(num_candidates) + + "_m_" + + str(m) + + "_efconstruction_" + + str(ef_construction) + + ".json" + ) + + if ( + pathlib.Path(perf_file).is_file() + and pathlib.Path(KNN_file).is_file() + ): + hnsw_perf = load_data(perf_file) + hnsw_nearest_neighbours = load_data(KNN_file) + + else: + hnsw_perf = {} + + hnsw_perf["num_candidates"] = num_candidates + + print(f"Creation of the {index_number}th hnsw index") + + hnsw_perf["creation_time"] = create_index( + es, index_name, d, ef_construction, m + ) + + ram_available_before = psutil.virtual_memory().available + + hnsw_perf["building_time"] = build_index( + es, index_name, embeddings_path, d, batch_size, index_size + ) + + hnsw_perf["ram_used"] = ( + ram_available_before - psutil.virtual_memory().available + ) / (1024.0**3) + + print(f"Search in the {index_number}th hsnw index") + + hnsw_nearest_neighbours, hnsw_perf["search_time"] = search_index( + es, + index_name, + embeddings_path, + queries, + batch_size, + K, + num_candidates, + ) + + hnsw_perf["metrics"] = compute_recall( + ground_truth, + hnsw_nearest_neighbours, + index_size, + queries, + K, + num_candidates, + ) + + save_data(perf_file, performances["hnsw"][index_number]) + save_data(KNN_file, hnsw_nearest_neighbours) + index_number = index_number + 1 + + return performances + + +if __name__ == "__main__": + + warnings.filterwarnings("ignore") + + es = Elasticsearch( + hosts="http://elastic:elastic_pswd@localhost:9200", verify_certs=False + ) + es.info().body + + index_name = "logos" + + ground_truth_path = pathlib.Path("faiss_saves_index_1000_queries_10/exact_KNN.json") + + embeddings_path = pathlib.Path("logos_embeddings_512.hdf5") + batch_size = 512 + d = 512 # dimension of the embeddings of the embeddings_path file + index_size = 1000 + queries = 10 + K = np.array([1, 5, 10, 50, 100]) + + NUM_array = np.array([110]) + EF_CONSTRUCTION_list = np.array([100]) + M_list = np.array([8]) + + save_dir = "elastic_index_" + str(index_size) + "_queries_" + str(queries) + + complete_perf = main( + es, + index_name, + embeddings_path, + ground_truth_path, + save_dir, + batch_size, + d, + queries, + index_size, + K, + NUM_array, + EF_CONSTRUCTION_list, + M_list, ) - with open(save_dir+"/elastic_complete_perf.json",'w') as f: - json.dump(complete_perf,f) + with open(save_dir + "/elastic_complete_perf.json", "w") as f: + json.dump(complete_perf, f) diff --git a/logo-ann/benchmarks/ANN_benchmark/faiss_benchmark.py b/logo-ann/benchmarks/ANN_benchmark/faiss_benchmark.py index 3a843715..4a1a8627 100644 --- a/logo-ann/benchmarks/ANN_benchmark/faiss_benchmark.py +++ b/logo-ann/benchmarks/ANN_benchmark/faiss_benchmark.py @@ -13,79 +13,89 @@ from recall_computation import compute_recall from utils import get_embedding, save_data, load_data -def create_index(d:int, exact : bool, m: int=32, efSearch: int=40, efConstruction: int=40): + +def create_index( + d: int, exact: bool, m: int = 32, efSearch: int = 40, efConstruction: int = 40 +): """ Creates an index object. - + Args: d: The dimensionality of the embeddings. exact: A boolean indicating whether to create an exact index or an approximate index. m: The number of connections that each node in the index has to its neighbors. efSearch: The maximum number of elements to visit when searching the graph. 
efConstruction: The maximum number of elements to visit when adding an element to the graph. - + Returns: A tuple containing the index object and the time it took to create the index. """ start_time = time.monotonic() - if exact : - index = faiss.index_factory(d,"IDMap,Flat") - else : + if exact: + index = faiss.index_factory(d, "IDMap,Flat") + else: index_hnsw = faiss.IndexHNSWFlat(d, m, faiss.METRIC_INNER_PRODUCT) index_hnsw.hnsw.efSearch = efSearch index_hnsw.hnsw.efConstruction = efConstruction index = faiss.IndexIDMap(index_hnsw) - return index, time.monotonic() - start_time + return index, time.monotonic() - start_time -def build_index(index: faiss.swigfaiss.IndexIDMap, embeddings_path: pathlib.Path, batch_size : int, n_vec : int): +def build_index( + index: faiss.swigfaiss.IndexIDMap, + embeddings_path: pathlib.Path, + batch_size: int, + n_vec: int, +): """ Builds an index with embeddings. - + Args: index: The index object. embeddings_path: The path to the file containing the embeddings. batch_size: The number of embeddings to process at a time. n_vec: The total number of embeddings to process. - + Returns: The time it took to add all the embeddings to the index. """ building_time = 0 - data_gen = get_embedding(embeddings_path, batch_size, n_vec) + data_gen = get_embedding(embeddings_path, batch_size, n_vec) for batch in tqdm.tqdm(data_gen): (embedding_batch, external_id_batch) = batch elapsed = time.monotonic() - index.add_with_ids(embedding_batch.astype('float32'),external_id_batch.astype("int64")) + index.add_with_ids( + embedding_batch.astype("float32"), external_id_batch.astype("int64") + ) building_time = building_time + time.monotonic() - elapsed return building_time def search_index( - index: faiss.swigfaiss.IndexIDMap, - embeddings_path: pathlib.Path, - queries: int, + index: faiss.swigfaiss.IndexIDMap, + embeddings_path: pathlib.Path, + queries: int, batch_size: int, K: np.array, - ): +): """ Performs k-nearest neighbor search on a set of queries. - + Args: index: The index object. embeddings_path: The path to the file containing the embeddings. queries: The number of queries to perform. batch_size: The number of queries to process at a time. K: The number of nearest neighbors to retrieve for each query. - + Returns: A tuple containing a dictionary of the nearest neighbors for each query and the mean search time for a query. """ @@ -98,16 +108,18 @@ def search_index( for batch in tqdm.tqdm(data_gen): (embeddings_batch, external_id_batch) = batch - + for i in range(len(embeddings_batch)): time_before = time.monotonic() - res = index.search(np.array([embeddings_batch[i]]).astype('float32'),int(K.max())+1) + res = index.search( + np.array([embeddings_batch[i]]).astype("float32"), int(K.max()) + 1 + ) - search_time.append(time.monotonic()-time_before) + search_time.append(time.monotonic() - time_before) - res = np.moveaxis(res,1,0) + res = np.moveaxis(res, 1, 0) query_id = str(external_id_batch[i]) @@ -116,13 +128,12 @@ def search_index( nearest_neighbours[query_id]["ids"] = res[0][1] nearest_neighbours[query_id]["distances"] = res[0][0] - return nearest_neighbours, np.mean(search_time) def main( exact: bool, - embeddings_path: pathlib.Path, + embeddings_path: pathlib.Path, save_dir: str, batch_size: int, d: int, @@ -132,10 +143,10 @@ def main( M_list: np.array, EF_CONSTRUCTION_list: np.array, EF_SEARCH_list: np.array, - ): +): """ Entry point of the script. Builds an index with embeddings and performs k-nearest neighbor search on a set of queries. 
- + Args: exact: A boolean indicating whether to create an exact index or an approximate index. embeddings_path: The path to the file containing the embeddings. @@ -180,16 +191,22 @@ def main( index, performances["ground_truth"]["creation_time"] = create_index(d, True) - performances["ground_truth"]["building_time"] = build_index(index, embeddings_path, batch_size, index_size) + performances["ground_truth"]["building_time"] = build_index( + index, embeddings_path, batch_size, index_size + ) values = psutil.virtual_memory() ram_available_after = values.available - performances["ground_truth"]["ram_used"] = (ram_available_before - ram_available_after)/(1024.0 ** 3) + performances["ground_truth"]["ram_used"] = ( + ram_available_before - ram_available_after + ) / (1024.0**3) print("Search in FLAT index to compute ground truth") - ground_truth, performances["ground_truth"]["search_time"] = search_index(index, embeddings_path, queries, batch_size, K) + ground_truth, performances["ground_truth"]["search_time"] = search_index( + index, embeddings_path, queries, batch_size, K + ) save_data(KNN_file, ground_truth) @@ -198,23 +215,44 @@ def main( performances["hnsw"] = {} index_number = 0 - for m in M_list : + for m in M_list: m = int(m) - for ef_construction in EF_CONSTRUCTION_list : + for ef_construction in EF_CONSTRUCTION_list: ef_construction = int(ef_construction) for ef_search in EF_SEARCH_list: ef_search = int(ef_search) print(f"****** {index_number}th HNSW index ******") - perf_file = save_dir + "/hnsw_perf_" + str(m) + "_" + str(ef_construction) + "_" + str(ef_search) + ".json" - KNN_file = save_dir + "/hnsw_ANN_" + str(m) + "_" + str(ef_construction) + "_" + str(ef_search) + ".json" - - if pathlib.Path(perf_file).is_file() and pathlib.Path(KNN_file).is_file(): + perf_file = ( + save_dir + + "/hnsw_perf_" + + str(m) + + "_" + + str(ef_construction) + + "_" + + str(ef_search) + + ".json" + ) + KNN_file = ( + save_dir + + "/hnsw_ANN_" + + str(m) + + "_" + + str(ef_construction) + + "_" + + str(ef_search) + + ".json" + ) + + if ( + pathlib.Path(perf_file).is_file() + and pathlib.Path(KNN_file).is_file() + ): hnsw_perf = load_data(perf_file) ground_truth = load_data(KNN_file) - else : + else: hnsw_perf = {} hnsw_perf["M"] = m @@ -227,23 +265,29 @@ def main( print(f"Creation of the {index_number}th jnsw index") index, hnsw_perf["creation_time"] = create_index( - d, - False, - m, - ef_search, - ef_construction, - ) + d, + False, + m, + ef_search, + ef_construction, + ) - hnsw_perf["building_time"] = build_index(index, embeddings_path, batch_size, index_size) + hnsw_perf["building_time"] = build_index( + index, embeddings_path, batch_size, index_size + ) values = psutil.virtual_memory() ram_available_after = values.available - hnsw_perf["ram_used"] = (ram_available_before - ram_available_after)/(1024.0 ** 3) + hnsw_perf["ram_used"] = ( + ram_available_before - ram_available_after + ) / (1024.0**3) print(f"Search in the {index_number}th hsnw index") - hnsw_nearest_neighbours, hnsw_perf["search_time"] = search_index(index, embeddings_path, queries, batch_size, K) + hnsw_nearest_neighbours, hnsw_perf["search_time"] = search_index( + index, embeddings_path, queries, batch_size, K + ) hnsw_perf["metrics"] = compute_recall( ground_truth, @@ -254,7 +298,7 @@ def main( m, ef_construction, ef_search, - ) + ) save_data(perf_file, hnsw_perf) save_data(KNN_file, hnsw_nearest_neighbours) @@ -264,31 +308,33 @@ def main( return performances -embeddings_path=pathlib.Path("logos_embeddings_512.hdf5") 
+embeddings_path = pathlib.Path("logos_embeddings_512.hdf5") batch_size = 512 d = 512 # dimension of the embeddings of the embeddings_path file index_size = 1000 queries = 10 # number of queries -K = np.array([1,4,10,50,100]) +K = np.array([1, 4, 10, 50, 100]) -M_array=np.array([5,30]) -efConstruction_array=np.array([16]) -efSearch_array=np.array([40]) +M_array = np.array([5, 30]) +efConstruction_array = np.array([16]) +efSearch_array = np.array([40]) save_dir = "faiss_saves_index_" + str(index_size) + "_queries_" + str(queries) -complete_perf = main(True, - embeddings_path, - save_dir, - batch_size, - d, - queries, - index_size, - K, - M_array, - efConstruction_array, - efSearch_array) - -with open(save_dir+"/faiss_complete_perf.json",'w') as f: - json.dump(complete_perf,f) +complete_perf = main( + True, + embeddings_path, + save_dir, + batch_size, + d, + queries, + index_size, + K, + M_array, + efConstruction_array, + efSearch_array, +) + +with open(save_dir + "/faiss_complete_perf.json", "w") as f: + json.dump(complete_perf, f) diff --git a/logo-ann/benchmarks/ANN_benchmark/recall_computation.py b/logo-ann/benchmarks/ANN_benchmark/recall_computation.py index 81dba6ae..d608db62 100644 --- a/logo-ann/benchmarks/ANN_benchmark/recall_computation.py +++ b/logo-ann/benchmarks/ANN_benchmark/recall_computation.py @@ -11,12 +11,11 @@ def compute_recall( index_size: int, queries: int, K: np.array, - m: int=0, - ef_construction: int=0, - ef_search: int=0, - num_candidates: int=0, - ): - + m: int = 0, + ef_construction: int = 0, + ef_search: int = 0, + num_candidates: int = 0, +): performances = {} @@ -31,7 +30,7 @@ def compute_recall( if num_candidates: performances["num_candidates"] = num_candidates - else : + else: performances["M"] = m performances["ef_construction"] = ef_construction performances["ef_search"] = ef_search @@ -54,13 +53,16 @@ def compute_recall( performances["macro_precision"][k] = 0 performances["micro_recall"][k] = 0 performances["micro_precision"][k] = 0 - + count = 0 for id in ground_truth.keys(): - if count >= queries : break - for k in K : + if count >= queries: + break + for k in K: k = int(k) - positive_neighbours = np.isin(approx_res[id]["ids"][1:k+1],ground_truth[id]["ids"][1:k+1]) + positive_neighbours = np.isin( + approx_res[id]["ids"][1 : k + 1], ground_truth[id]["ids"][1 : k + 1] + ) tp = np.sum(positive_neighbours.astype(int)) tp_micro[k] = tp_micro[k] + tp @@ -68,23 +70,27 @@ def compute_recall( fp = k - tp fp_micro[k] = fp_micro[k] + fp - found_neighbours_among_the_expected_ones = np.isin(ground_truth[id]["ids"][1:k+1],approx_res[id]["ids"][1:k+1]) + found_neighbours_among_the_expected_ones = np.isin( + ground_truth[id]["ids"][1 : k + 1], approx_res[id]["ids"][1 : k + 1] + ) fn = k - (np.sum(found_neighbours_among_the_expected_ones.astype(int))) - fn_micro[k] = fn_micro[k] + fn + fn_micro[k] = fn_micro[k] + fn - performances["macro_recall"][k] = performances["macro_recall"][k] + tp/(tp+fn) - performances["macro_precision"][k] = performances["macro_precision"][k] + tp/(tp+fp) - + performances["macro_recall"][k] = performances["macro_recall"][k] + tp / ( + tp + fn + ) + performances["macro_precision"][k] = performances["macro_precision"][ + k + ] + tp / (tp + fp) count = count + 1 - for k in K: k = int(k) - performances["micro_recall"][k] = tp_micro[k]/(tp_micro[k]+fn_micro[k]) - performances["micro_precision"][k] = tp_micro[k]/(tp_micro[k]+fp_micro[k]) - performances["macro_recall"][k] = performances["macro_recall"][k]/count - performances["macro_precision"][k] = 
performances["macro_precision"][k]/count + performances["micro_recall"][k] = tp_micro[k] / (tp_micro[k] + fn_micro[k]) + performances["micro_precision"][k] = tp_micro[k] / (tp_micro[k] + fp_micro[k]) + performances["macro_recall"][k] = performances["macro_recall"][k] / count + performances["macro_precision"][k] = performances["macro_precision"][k] / count return performances @@ -94,29 +100,44 @@ def compute_recall( d = 768 index_size = 4371343 queries = 1000 - K = np.array([1,4,10,50,100]) + K = np.array([1, 4, 10, 50, 100]) M = 16 ef_construction = 100 ef_search = 128 num_candidates = 110 - save_dir = "computation_512_recall_index_" + str(index_size) + "_queries_" + str(queries) - - ground_truth = load_data("faiss_saves_index_4371343_queries_1000/512_exact_KNN.json") - #approx_res = load_data("PCA_300_saves_index_4371343_queries_1000/hnsw_ANN_"+str(M)+"_"+str(ef_construction)+"_"+str(ef_search)+".json") - approx_res = load_data("elastic_index_4371343_queries_1000/512_ef_hnsw_ANN_ncandidates_110_m_16_efconstruction_100.json") - - - - complete_perf = compute_recall(ground_truth, - approx_res, - index_size, - queries, - K, - M, - ef_construction, - ef_search, - ) - - with open(save_dir+"/512_num_candidates_"+str(num_candidates)+"_m_"+str(M)+"_ef_construction_"+str(ef_construction)+".json",'w') as f: - json.dump(complete_perf,f) + save_dir = ( + "computation_512_recall_index_" + str(index_size) + "_queries_" + str(queries) + ) + + ground_truth = load_data( + "faiss_saves_index_4371343_queries_1000/512_exact_KNN.json" + ) + # approx_res = load_data("PCA_300_saves_index_4371343_queries_1000/hnsw_ANN_"+str(M)+"_"+str(ef_construction)+"_"+str(ef_search)+".json") + approx_res = load_data( + "elastic_index_4371343_queries_1000/512_ef_hnsw_ANN_ncandidates_110_m_16_efconstruction_100.json" + ) + + complete_perf = compute_recall( + ground_truth, + approx_res, + index_size, + queries, + K, + M, + ef_construction, + ef_search, + ) + + with open( + save_dir + + "/512_num_candidates_" + + str(num_candidates) + + "_m_" + + str(M) + + "_ef_construction_" + + str(ef_construction) + + ".json", + "w", + ) as f: + json.dump(complete_perf, f) diff --git a/logo-ann/benchmarks/ANN_benchmark/redis_benchmark.py b/logo-ann/benchmarks/ANN_benchmark/redis_benchmark.py index 83f221ed..84cce8e5 100644 --- a/logo-ann/benchmarks/ANN_benchmark/redis_benchmark.py +++ b/logo-ann/benchmarks/ANN_benchmark/redis_benchmark.py @@ -16,20 +16,19 @@ from redis.commands.search.query import Query - def create_index( - client : Redis, - embedding_field_name : str, - external_id_field_name : str, - dim : int, - exact : bool, - M : int=16, - EF_CONSTRUCTION : int=200, - EF_RUNTIME : int=10, - ): + client: Redis, + embedding_field_name: str, + external_id_field_name: str, + dim: int, + exact: bool, + M: int = 16, + EF_CONSTRUCTION: int = 200, + EF_RUNTIME: int = 10, +): """ Creates an index with either a flat or HNSW structure. - + Args: client: The redis client object. embedding_field_name: The name of the field containing the embeddings. @@ -39,7 +38,7 @@ def create_index( M: The number of connections that each node in the index has to its neighbors. EF_CONSTRUCTION: The maximum number of elements to visit when adding an element to the graph. EF_RUNTIME: The maximum number of elements to visit when searching the graph. - + Returns: The time it took to create the index. 
""" @@ -54,58 +53,75 @@ def create_index( except: print("No HNSW Index to drop") - percent = '0' # percentage of index that has been built + percent = "0" # percentage of index that has been built count = 0 - if os.path.exists("percentage_study.txt"): os.remove("percentage_study.txt") + if os.path.exists("percentage_study.txt"): + os.remove("percentage_study.txt") start_time = time.monotonic() - if exact : - schema = (VectorField(embedding_field_name, "FLAT", {"TYPE": "FLOAT32", - "DIM": dim, - "DISTANCE_METRIC": "COSINE"}), - NumericField(external_id_field_name)) + if exact: + schema = ( + VectorField( + embedding_field_name, + "FLAT", + {"TYPE": "FLOAT32", "DIM": dim, "DISTANCE_METRIC": "COSINE"}, + ), + NumericField(external_id_field_name), + ) client.ft("Flat").create_index(schema) client.ft("Flat").config_set("default_dialect", 2) - while percent != '1': + while percent != "1": time.sleep(1) - count = count+1 + count = count + 1 percent = client.ft("Flat").info()["percent_indexed"] - with open("percentage_study.txt","a") as f: - f.write(f"Percentage of Flat indexed {percent} at time {count*20} sec \n") - - else : - schema = (VectorField(embedding_field_name, "HNSW", {"TYPE": "FLOAT32", - "DIM": dim, - "DISTANCE_METRIC": "COSINE", - "M":M, - "EF_CONSTRUCTION": EF_CONSTRUCTION, - "EF_RUNTIME": EF_RUNTIME}), - NumericField(external_id_field_name)) + with open("percentage_study.txt", "a") as f: + f.write( + f"Percentage of Flat indexed {percent} at time {count*20} sec \n" + ) + + else: + schema = ( + VectorField( + embedding_field_name, + "HNSW", + { + "TYPE": "FLOAT32", + "DIM": dim, + "DISTANCE_METRIC": "COSINE", + "M": M, + "EF_CONSTRUCTION": EF_CONSTRUCTION, + "EF_RUNTIME": EF_RUNTIME, + }, + ), + NumericField(external_id_field_name), + ) client.ft("HNSW").create_index(schema) client.ft("HNSW").config_set("default_dialect", 2) - while percent != '1': + while percent != "1": time.sleep(1) - count = count+1 + count = count + 1 percent = client.ft("HNSW").info()["percent_indexed"] - with open("percentage_study.txt","a") as f: - f.write(f"Percentage of HNSW indexed {percent} at time {count*240} sec \n") + with open("percentage_study.txt", "a") as f: + f.write( + f"Percentage of HNSW indexed {percent} at time {count*240} sec \n" + ) - return time.monotonic() - start_time + return time.monotonic() - start_time def build_index( - client : Redis, + client: Redis, embedding_path: pathlib.Path, - embedding_field_name : str, - external_id_field_name : str, - n_vec : int, - batch_size : int, - ): + embedding_field_name: str, + external_id_field_name: str, + n_vec: int, + batch_size: int, +): """ Builds the index by adding the embeddings to the client. - + Args: client: The redis client object. embedding_path: The path to the file containing the embeddings. @@ -113,12 +129,12 @@ def build_index( external_id_field_name: The name of the field containing the external ids. n_vec: The number of embeddings to index. batch_size: The number of embeddings to add to the index at a time. - + Returns: The time it took to build the index. 
- """ + """ - data_gen = get_embedding(embedding_path, batch_size, n_vec) + data_gen = get_embedding(embedding_path, batch_size, n_vec) offset = 0 search_time = [] for batch in tqdm.tqdm(data_gen): @@ -126,8 +142,15 @@ def build_index( for i in range(len(embeddings_batch)): start_time = time.monotonic() - client.hset(i+offset, mapping = {embedding_field_name: embeddings_batch[i].astype('float32').tobytes(), - external_id_field_name: int(external_id_batch[i])}) + client.hset( + i + offset, + mapping={ + embedding_field_name: embeddings_batch[i] + .astype("float32") + .tobytes(), + external_id_field_name: int(external_id_batch[i]), + }, + ) search_time.append((time.monotonic() - start_time)) offset = offset + len(embeddings_batch) return np.mean(search_time) @@ -135,16 +158,16 @@ def build_index( def search_index( client: Redis, - exact: bool, - embeddings_path: pathlib.Path, - queries: int, + exact: bool, + embeddings_path: pathlib.Path, + queries: int, batch_size: int, - embedding_field_name : str, + embedding_field_name: str, K: np.array, - ): +): """ Searches the index for the given embeddings and returns the nearest neighbors and search times. - + Args: client: The redis client object. exact: A flag indicating whether to use the exact or approximate index. @@ -153,7 +176,7 @@ def search_index( batch_size: The size of each batch of queries. embedding_field_name: The name of the field containing the embeddings. K: The array of number of nearest neighbors to retrieve for each query. - + Returns: nearest_neighbors: A dictionary containing the nearest neighbors for each query. search_time: The mean search time across all queries. @@ -165,57 +188,73 @@ def search_index( data_gen = get_embedding(embeddings_path, batch_size, queries) - if exact : idx_name = "Flat" - else : idx_name = "HNSW" + if exact: + idx_name = "Flat" + else: + idx_name = "HNSW" for batch in tqdm.tqdm(data_gen): (embeddings_batch, external_id_batch) = batch for i in range(len(embeddings_batch)): - query_vector = embeddings_batch[i].astype('float32') + query_vector = embeddings_batch[i].astype("float32") query_id = external_id_batch[i] assert len(redis_conn.execute_command("FT._LIST")) == 1 - q = Query(f'*=>[KNN $k @{embedding_field_name} $vec_param AS dist]').paging(0,101).sort_by(f'dist') - + q = ( + Query(f"*=>[KNN $k @{embedding_field_name} $vec_param AS dist]") + .paging(0, 101) + .sort_by(f"dist") + ) + start_time = time.monotonic() - res = client.ft(idx_name).search(q, query_params = {'k': int(K.max()+1),'vec_param': query_vector.tobytes()}) + res = client.ft(idx_name).search( + q, + query_params={ + "k": int(K.max() + 1), + "vec_param": query_vector.tobytes(), + }, + ) search_time.append(time.monotonic() - start_time) query_id = str(query_id) nearest_neighbours[query_id] = {} - nearest_neighbours[query_id]["ids"] = np.array([int(doc.external_id) for doc in res.docs]) - nearest_neighbours[query_id]["distances"] = np.array([float(doc.dist) for doc in res.docs]) + nearest_neighbours[query_id]["ids"] = np.array( + [int(doc.external_id) for doc in res.docs] + ) + nearest_neighbours[query_id]["distances"] = np.array( + [float(doc.dist) for doc in res.docs] + ) return nearest_neighbours, np.mean(search_time) def evaluate( - embeddings_path: pathlib.Path, + embeddings_path: pathlib.Path, save_dir: str, batch_size: int, client: Redis, d: int, embedding_field_name: str, - external_id_field_name: str, + external_id_field_name: str, queries: int, index_size: int, K: np.array, M_list: np.array, EF_CONSTRUCTION_list: np.array, 
EF_RUNTIME_list: np.array, - ): +): """ Evaluate the performance of the search index. - + This function tests the performance of the search index by building and querying both a flat index and an HNSW index. The performance of each index is evaluated based on search time and the accuracy of the nearest neighbor search results. The results are saved to the specified directory. - + Args: embeddings_path (pathlib.Path): The path to the file containing the embeddings to index. @@ -267,20 +306,33 @@ def evaluate( print("Creation of FLAT index") - performances["ground_truth"]["creation_time"] = create_index(client, embedding_field_name, external_id_field_name, d, exact = True) - + performances["ground_truth"]["creation_time"] = create_index( + client, embedding_field_name, external_id_field_name, d, exact=True + ) + if client.dbsize() != index_size and client.dbsize() < 4317343: - performances["ground_truth"]["building_time"] = build_index(client, embeddings_path, embedding_field_name, external_id_field_name, index_size, batch_size) + performances["ground_truth"]["building_time"] = build_index( + client, + embeddings_path, + embedding_field_name, + external_id_field_name, + index_size, + batch_size, + ) values = psutil.virtual_memory() ram_available_after = values.available - performances["ground_truth"]["ram_used"] = (ram_available_before - ram_available_after)/(1024.0 ** 3) + performances["ground_truth"]["ram_used"] = ( + ram_available_before - ram_available_after + ) / (1024.0**3) print("Search in FLAT index to compute ground truth") - ground_truth, performances["ground_truth"]["search_time"] = search_index(client, True, embeddings_path, queries, batch_size, embedding_field_name, K) - + ground_truth, performances["ground_truth"]["search_time"] = search_index( + client, True, embeddings_path, queries, batch_size, embedding_field_name, K + ) + save_data(KNN_file, ground_truth) save_data(perf_file, performances["ground_truth"]) @@ -288,23 +340,44 @@ def evaluate( performances["hnsw"] = {} index_number = 0 - for m in M_list : + for m in M_list: m = int(m) - for ef_construction in EF_CONSTRUCTION_list : + for ef_construction in EF_CONSTRUCTION_list: ef_construction = int(ef_construction) for ef_runtime in EF_RUNTIME_list: ef_runtime = int(ef_runtime) print(f"****** {index_number}th HNSW index ******") - perf_file = save_dir + "/hnsw_perf_" + str(m) + "_" + str(ef_construction) + "_" + str(ef_runtime) + ".json" - KNN_file = save_dir + "/hnsw_ANN_" + str(m) + "_" + str(ef_construction) + "_" + str(ef_runtime) + ".json" - - if pathlib.Path(perf_file).is_file() and pathlib.Path(KNN_file).is_file(): + perf_file = ( + save_dir + + "/hnsw_perf_" + + str(m) + + "_" + + str(ef_construction) + + "_" + + str(ef_runtime) + + ".json" + ) + KNN_file = ( + save_dir + + "/hnsw_ANN_" + + str(m) + + "_" + + str(ef_construction) + + "_" + + str(ef_runtime) + + ".json" + ) + + if ( + pathlib.Path(perf_file).is_file() + and pathlib.Path(KNN_file).is_file() + ): hnsw_perf = load_data(perf_file) ground_truth = load_data(KNN_file) - else : + else: hnsw_perf = {} hnsw_perf["M"] = m @@ -317,12 +390,21 @@ def evaluate( ram_available_before = values.available if client.dbsize() != index_size and client.dbsize() < 4317343: - hnsw_perf["bulding_time"] = build_index(client, embeddings_path, embedding_field_name, external_id_field_name, index_size, batch_size) + hnsw_perf["bulding_time"] = build_index( + client, + embeddings_path, + embedding_field_name, + external_id_field_name, + index_size, + batch_size, + ) values = 
psutil.virtual_memory() ram_available_after = values.available - hnsw_perf["ram_used_for_db"] = (ram_available_before - ram_available_after)/(1024.0 ** 3) + hnsw_perf["ram_used_for_db"] = ( + ram_available_before - ram_available_after + ) / (1024.0**3) values = psutil.virtual_memory() ram_available_before = values.available @@ -330,24 +412,34 @@ def evaluate( print(f"Creation of the {index_number}th jnsw index") hnsw_perf["creation_time"] = create_index( - client, - embedding_field_name, - external_id_field_name, - d, + client, + embedding_field_name, + external_id_field_name, + d, False, m, ef_construction, ef_runtime, - ) + ) values = psutil.virtual_memory() ram_available_after = values.available - hnsw_perf["ram_used"] = (ram_available_before - ram_available_after)/(1024.0 ** 3) + hnsw_perf["ram_used"] = ( + ram_available_before - ram_available_after + ) / (1024.0**3) print(f"Search in the {index_number}th hsnw index") - hnsw_nearest_neighbours, hnsw_perf["search_time"] = search_index(client, False, embeddings_path, queries, batch_size, embedding_field_name, K) + hnsw_nearest_neighbours, hnsw_perf["search_time"] = search_index( + client, + False, + embeddings_path, + queries, + batch_size, + embedding_field_name, + K, + ) hnsw_perf["metrics"] = compute_recall( ground_truth, @@ -358,54 +450,52 @@ def evaluate( m, ef_construction, ef_runtime, - ) + ) save_data(perf_file, hnsw_perf) save_data(KNN_file, hnsw_nearest_neighbours) index_number = index_number + 1 performances["hnsw"][index_number] = hnsw_perf - return performances host = "localhost" port = 6379 -redis_conn = Redis(host = host, port = port) +redis_conn = Redis(host=host, port=port) -embeddings_path=pathlib.Path("logos_embeddings_512.hdf5") +embeddings_path = pathlib.Path("logos_embeddings_512.hdf5") batch_size = 512 d = 512 index_size = 1000 queries = 10 -K = np.array([1,4,10,50,100]) +K = np.array([1, 4, 10, 50, 100]) -M_array=np.array([5,50]) -efConstruction_array=np.array([256]) -efRuntime_array=np.array([128]) +M_array = np.array([5, 50]) +efConstruction_array = np.array([256]) +efRuntime_array = np.array([128]) embedding_field_name = "embedding" external_id_field_name = "external_id" save_dir = "redis_saves_index_" + str(index_size) + "_queries_" + str(queries) -complete_perf = evaluate(embeddings_path, - save_dir, - batch_size, - redis_conn, - d, - embedding_field_name, - external_id_field_name, - queries, - index_size, - K, - M_array, - efConstruction_array, - efRuntime_array, - ) - -with open(save_dir+"/redis_complete_perf.json",'w') as f: - json.dump(complete_perf,f) - - +complete_perf = evaluate( + embeddings_path, + save_dir, + batch_size, + redis_conn, + d, + embedding_field_name, + external_id_field_name, + queries, + index_size, + K, + M_array, + efConstruction_array, + efRuntime_array, +) + +with open(save_dir + "/redis_complete_perf.json", "w") as f: + json.dump(complete_perf, f) diff --git a/logo-ann/benchmarks/ANN_benchmark/utils.py b/logo-ann/benchmarks/ANN_benchmark/utils.py index 701a12d9..fb4cd7a5 100644 --- a/logo-ann/benchmarks/ANN_benchmark/utils.py +++ b/logo-ann/benchmarks/ANN_benchmark/utils.py @@ -4,12 +4,14 @@ import pathlib from more_itertools import chunked + class NumpyArrayEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, np.ndarray): return obj.tolist() return json.JSONEncoder.default(self, obj) + def get_embedding(embeddings_path: pathlib.Path, batch_size: int, nb_embeddings: int): """ Get embeddings from an embeddings file. 
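# A minimal consumption sketch for the get_embedding() generator defined above.
# The file name, batch size and embedding count simply mirror the values used by
# redis_benchmark.py in this folder; treat them as placeholders for your own data.
import pathlib

import tqdm

data_gen = get_embedding(
    pathlib.Path("logos_embeddings_512.hdf5"), batch_size=512, nb_embeddings=1000
)
for embeddings_batch, external_id_batch in tqdm.tqdm(data_gen):
    # Each yielded batch is a pair of aligned numpy arrays: one embedding row per
    # kept logo and the matching external ids.
    print(embeddings_batch.shape, external_id_batch.shape)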
@@ -35,7 +37,9 @@ def get_embedding(embeddings_path: pathlib.Path, batch_size: int, nb_embeddings: embedding_dset = f["embedding"] external_id_dset = f["external_id"] - for slicing in chunked(range(min(len(embedding_dset), nb_embeddings)), batch_size): + for slicing in chunked( + range(min(len(embedding_dset), nb_embeddings)), batch_size + ): slicing = np.array(slicing) mask = external_id_dset[slicing] == 0 @@ -48,12 +52,13 @@ def get_embedding(embeddings_path: pathlib.Path, batch_size: int, nb_embeddings: external_id_dset[slicing][mask], ) + def save_data(file_name: str, data: dict): - with open(file_name, 'w') as f: - json.dump(data,f,cls=NumpyArrayEncoder) + with open(file_name, "w") as f: + json.dump(data, f, cls=NumpyArrayEncoder) + def load_data(file_name: str): - with open(file_name, 'r') as f: + with open(file_name, "r") as f: data = json.load(f) return data - diff --git a/logo-ann/benchmarks/embedding_models_benchmark/main.py b/logo-ann/benchmarks/embedding_models_benchmark/main.py index 0539bf67..4075b463 100644 --- a/logo-ann/benchmarks/embedding_models_benchmark/main.py +++ b/logo-ann/benchmarks/embedding_models_benchmark/main.py @@ -113,14 +113,14 @@ def generate_embeddings_clip( def pairwise_squared_euclidian_distance_numpy(A: np.ndarray) -> np.ndarray: assert len(A.shape) == 2 dot_product = np.dot(A[:, None, :], A[None, :, :].swapaxes(1, 2)).squeeze() - squared_sum = np.sum(A ** 2.0, axis=1, keepdims=True) + squared_sum = np.sum(A**2.0, axis=1, keepdims=True) return squared_sum + squared_sum.transpose() - 2 * dot_product def pairwise_squared_euclidian_distance(A: np.ndarray) -> torch.Tensor: assert len(A.shape) == 2 dot_product = torch.matmul(A[:, None, :], A[None, :, :].swapaxes(1, 2)).squeeze() - squared_sum = torch.sum(A ** 2.0, axis=1, keepdim=True) + squared_sum = torch.sum(A**2.0, axis=1, keepdim=True) return squared_sum + squared_sum.T - 2 * dot_product @@ -189,11 +189,11 @@ def get_distance_matrix_mask( """ It returns a boolean matrix of shape size*size. - mask[n,m] = False If we want to keep this logo for the - computation of the closest neighbours. + mask[n,m] = False If we want to keep this logo for the + computation of the closest neighbours. - The idea is to keep at max max_label_count logos per label and not to take - into account the embedding itself when computeing the closest neighbours. + The idea is to keep at max max_label_count logos per label and not to take + into account the embedding itself when computeing the closest neighbours. 
""" mask = np.ones((size, size), dtype=bool) diff --git a/logo-ann/dataset/clean_logo_dataset.py b/logo-ann/dataset/clean_logo_dataset.py index a91adf60..a4c50779 100644 --- a/logo-ann/dataset/clean_logo_dataset.py +++ b/logo-ann/dataset/clean_logo_dataset.py @@ -38,7 +38,6 @@ } - TO_REMOVE = { "packaging_fr_barquette-et-film-plastique-a-jeter", "packaging_fr_pensez-au-tri-!", diff --git a/logo-ann/dataset/create_logo_dataset.py b/logo-ann/dataset/create_logo_dataset.py index 12d714bb..f62c8caf 100644 --- a/logo-ann/dataset/create_logo_dataset.py +++ b/logo-ann/dataset/create_logo_dataset.py @@ -52,7 +52,8 @@ def item_group_func(x): def filter_items(items, min_count: int = 0): valid_items = [item for item in items if is_valid_item(item)] for key, group in itertools.groupby( - sorted(valid_items, key=item_group_func), item_group_func, + sorted(valid_items, key=item_group_func), + item_group_func, ): group_items = list(group) if len(group_items) >= min_count: @@ -142,7 +143,8 @@ def save_logo(item, image: Image.Image, output_dir: Path): itertools.groupby( sorted(filtered_items, key=operator.itemgetter("source_image")), operator.itemgetter("source_image"), - ), total=total_groups + ), + total=total_groups, ): if source_image in seen_images: print(f"Skipping {source_image}") diff --git a/logo-ann/generation/02_generate_embeddings.py b/logo-ann/generation/02_generate_embeddings.py index 4c814e78..02f519b3 100644 --- a/logo-ann/generation/02_generate_embeddings.py +++ b/logo-ann/generation/02_generate_embeddings.py @@ -30,7 +30,6 @@ def build_model(model_type: str): def get_output_dim(model_type: str): - """Return the embeddings size according to the model used.""" if model_type == "clip-vit-base-patch16" or model_type == "clip-vit-base-patch32": @@ -57,17 +56,16 @@ def generate_embeddings_iter( min_confidence: float = 0.5, processor: Any = None, ): - """Inputs: - model: name of the specific model used - file_path: path of the hdf5 file containing the data of all the logos - - batch-size: size of each batche of logos embedded at the same time + - batch-size: size of each batche of logos embedded at the same time - device: hardware used to compute the embeddings - - seen_set: set of every logo already embedded in + - seen_set: set of every logo already embedded in - min-confidence: minimum of confidence allowed for a logo to be accepted as one Yield the following outputs: - - embeddings: embeddings of every logo of the yielded batch + - embeddings: embeddings of every logo of the yielded batch - external_id: id of the logo """ @@ -76,7 +74,7 @@ def generate_embeddings_iter( confidence_dset = f["confidence"] external_id_dset = f["external_id"] - embeddings_test=[] + embeddings_test = [] for slicing in chunked(range(len(image_dset)), batch_size): slicing = np.array( @@ -98,41 +96,55 @@ def generate_embeddings_iter( if int(external_id) in seen_set: mask[i] = 0 - if np.all(~mask): # if we only have zeros at this step, we have a batch only with empty data or already seen logos + if np.all( + ~mask + ): # if we only have zeros at this step, we have a batch only with empty data or already seen logos continue images = image_dset[slicing][mask] - with torch.no_grad(): # Preprocess the images to put them into the model - inputs = processor(images=[PIL.Image.fromarray(images[i], mode="RGB").to_device(device) for i in range(min(batch_size,len(images)))], - return_tensors="pt", - padding=True).pixel_values + inputs = processor( + images=[ + PIL.Image.fromarray(images[i], mode="RGB").to_device(device) + for i 
in range(min(batch_size, len(images))) + ], + return_tensors="pt", + padding=True, + ).pixel_values # Passing logos through the model # We don't have text to pass to the model so we use a (1,1) attention mask - # and use (BOS, EOS) as input for text. - outputs = model(**{'pixel_values':inputs, - 'attention_mask':torch.from_numpy(np.ones((len(images),2), dtype=int)), - 'input_ids':torch.from_numpy(np.ones((len(images),2),dtype=int)*[49406,49407])}) + # and use (BOS, EOS) as input for text. + outputs = model( + **{ + "pixel_values": inputs, + "attention_mask": torch.from_numpy( + np.ones((len(images), 2), dtype=int) + ), + "input_ids": torch.from_numpy( + np.ones((len(images), 2), dtype=int) * [49406, 49407] + ), + } + ) # Getting logo embeddings out of the outputs embeddings = outputs.image_embeds.detach().numpy() if np.any(np.isnan(embeddings)): # checking values are not NaN print("A NaN value was detected, avoiding the loop") continue - + yield (embeddings, external_ids[mask]) + def generate_embedding_from_hdf5( data_gen: Iterable, output_path: pathlib.Path, output_dim: int, count: int ): - """Save the embedding and the external id of each logo (data in data_gen) in an hdf5 file (the output_path). - data_gen: yielded embeddings and external ids of each logo from generate_embeddings_iter - output_path: path of the output hdf5 file - output_dim: dimension of the embeddings (depends on the computer vision model used) - - count: amount of embeddings you want to save + - count: amount of embeddings you want to save """ file_exists = output_path.is_file() @@ -153,7 +165,7 @@ def generate_embedding_from_hdf5( print("Offset: {}".format(offset)) - for (embeddings_batch, external_id_batch) in data_gen: + for embeddings_batch, external_id_batch in data_gen: slicing = slice(offset, offset + len(embeddings_batch)) embedding_dset[slicing] = embeddings_batch external_id_dset[slicing] = external_id_batch diff --git a/logo-ann/generation/03_generate_index.py b/logo-ann/generation/03_generate_index.py index 2e87508a..a48d13ad 100644 --- a/logo-ann/generation/03_generate_index.py +++ b/logo-ann/generation/03_generate_index.py @@ -1,7 +1,7 @@ import argparse import pathlib -import faiss +import faiss from annoy import AnnoyIndex import h5py from more_itertools import chunked @@ -21,6 +21,7 @@ KNN-library: name of the specific KNN library used (annoy, faiss) """ + def generate_embeddings_iter(file_path: pathlib.Path, batch_size: int): with h5py.File(str(file_path), "r") as f: embedding_dset = f["embedding"] @@ -49,21 +50,25 @@ def generate_index_from_hdf5( data_gen = generate_embeddings_iter(file_path, batch_size) index = None - output_path = pathlib.Path(str(output_path)+"_faiss") + output_path = pathlib.Path(str(output_path) + "_faiss") assert not output_path.is_file() for batch in tqdm.tqdm(data_gen): (embeddings_batch, external_id_batch) = batch - for (embedding, external_id) in zip(embeddings_batch, external_id_batch): + for embedding, external_id in zip(embeddings_batch, external_id_batch): if index is None: output_dim = embeddings_batch.shape[-1] - index = faiss.index_factory(output_dim,"IDMap,HNSW") # "IDMap,Flat" "IDMap,HNSW" + index = faiss.index_factory( + output_dim, "IDMap,HNSW" + ) # "IDMap,Flat" "IDMap,HNSW" - index.add_with_ids(np.array([embedding]).astype('float32'),np.array([int(external_id)])) + index.add_with_ids( + np.array([embedding]).astype("float32"), np.array([int(external_id)]) + ) if index is not None: - print("virtual memory",psutil.virtual_memory()) + print("virtual memory", 
psutil.virtual_memory()) faiss.write_index(index, str(output_path)) @@ -80,5 +85,8 @@ def parse_args(): args = parse_args() assert args.data_path.is_file() generate_index_from_hdf5( - args.data_path, args.output_path, args.batch_size, args.tree_count, + args.data_path, + args.output_path, + args.batch_size, + args.tree_count, ) diff --git a/logo-classifier/annotate_logos.py b/logo-classifier/annotate_logos.py index 36980bdf..2cf89fe7 100644 --- a/logo-classifier/annotate_logos.py +++ b/logo-classifier/annotate_logos.py @@ -6,16 +6,18 @@ import re import tqdm + def modify_hdf5_dataset(hdf5_file: str, changed_logos: dict): changed_ids = set(changed_logos.keys()) - with h5py.File(hdf5_file, 'a') as f: + with h5py.File(hdf5_file, "a") as f: offset = get_offset(f) - ids = f['external_id'][:offset] + ids = f["external_id"][:offset] for i in range(len(ids)): id = str(ids[i]) if id in changed_ids: - f['class'][i] = changed_logos[id] + f["class"][i] = changed_logos[id] + def get_changed_logos(missed_logos_dir: str): classes_str, classes_ids = get_labels(settings.labels_path, []) @@ -24,22 +26,25 @@ def get_changed_logos(missed_logos_dir: str): missed_logos = {} for root, dir, files in os.walk(missed_logos_dir): - splitted = root.split('/') - if splitted[-1] == 'true': + splitted = root.split("/") + if splitted[-1] == "true": predicted = splitted[-3] for logo in files: - logo_id = re.search('\_.*\.', logo).group(0)[1:-1] - missed_logos[logo_id] = int(classes_ids[np.where(classes_str==predicted)][0]) + logo_id = re.search("\_.*\.", logo).group(0)[1:-1] + missed_logos[logo_id] = int( + classes_ids[np.where(classes_str == predicted)][0] + ) true = splitted[-2] return missed_logos -if __name__ == '__main__': - ''' + +if __name__ == "__main__": + """ You can run this script once you have annotated each logo. - * dataset_to_modify: paht to the hdf5 dataset you want to modify by your annotations. + * dataset_to_modify: paht to the hdf5 dataset you want to modify by your annotations. * annotation_dir: directory you used to annotate logos. 
- ''' + """ dataset_to_modify = "datasets/test-val_dataset.hdf5" annotation_dir = "missed_logos" changed_logos = get_changed_logos(annotation_dir) - modify_hdf5_dataset("datasets/test-val_dataset.hdf5", changed_logos) \ No newline at end of file + modify_hdf5_dataset("datasets/test-val_dataset.hdf5", changed_logos) diff --git a/logo-classifier/dataset.py b/logo-classifier/dataset.py index d651e5af..54e973d6 100644 --- a/logo-classifier/dataset.py +++ b/logo-classifier/dataset.py @@ -7,16 +7,20 @@ from utils import get_offset import numpy as np -class LogoDataset(t_data.Dataset): - '''Class for main Dataset Classes''' - def __init__(self, hdf5_file: Path, transforms: List[str], prohibited_classes: List[int]): - self.file = h5py.File(hdf5_file, 'r') + +class LogoDataset(t_data.Dataset): + """Class for main Dataset Classes""" + + def __init__( + self, hdf5_file: Path, transforms: List[str], prohibited_classes: List[int] + ): + self.file = h5py.File(hdf5_file, "r") offset = get_offset(self.file) - embeddings = self.file['embedding'][:offset] - ids = self.file['external_id'][:offset] + embeddings = self.file["embedding"][:offset] + ids = self.file["external_id"][:offset] assert ids[-1] != 0 ids = ids - classes = self.file['class'][:offset] + classes = self.file["class"][:offset] self.transforms = transforms for prohib_classe in prohibited_classes: embeddings = embeddings[np.where(classes != prohib_classe)] @@ -25,10 +29,10 @@ def __init__(self, hdf5_file: Path, transforms: List[str], prohibited_classes: L self.embeddings = embeddings self.ids = ids self.classes = classes - + def __len__(self): return len(self.embeddings) - + def __getitem__(self, idx): embedding = self.embeddings[idx] for transform in self.transforms: @@ -37,60 +41,73 @@ def __getitem__(self, idx): class_arr[self.classes[idx]] = 1 return embedding, class_arr, self.ids[idx] -def get_datasets(train_path: Path, val_path: Path, test_path: Path, transforms: List[str], prohibited_classes: List[int]): + +def get_datasets( + train_path: Path, + val_path: Path, + test_path: Path, + transforms: List[str], + prohibited_classes: List[int], +): train_dataset = LogoDataset(train_path, transforms, prohibited_classes) val_dataset = LogoDataset(val_path, transforms, prohibited_classes) test_dataset = LogoDataset(test_path, transforms, prohibited_classes) return train_dataset, val_dataset, test_dataset -def get_weights(datasets_list: List[LogoDataset], loader_batch_size: int, num_threads: int): + +def get_weights( + datasets_list: List[LogoDataset], loader_batch_size: int, num_threads: int +): res_list = [] for dataset in datasets_list: amount_dict = {} - data_gen = t_data.DataLoader(dataset=dataset, batch_size=loader_batch_size, num_workers=num_threads) + data_gen = t_data.DataLoader( + dataset=dataset, batch_size=loader_batch_size, num_workers=num_threads + ) print("Starting weight generation loop 1") for data_batch in tqdm.tqdm(data_gen): class_batch = data_batch[1] for classe in class_batch: - classe = torch.where(classe==1)[0] + classe = torch.where(classe == 1)[0] try: - amount_dict[str(classe.item())]+=1 + amount_dict[str(classe.item())] += 1 except KeyError: - amount_dict[str(classe.item())]=1 + amount_dict[str(classe.item())] = 1 weight_dict = {} for classe in amount_dict: if amount_dict[classe] == 0: print("there is no class") - weight_dict[classe] = 1/amount_dict[classe] + weight_dict[classe] = 1 / amount_dict[classe] weight_list = [] print("Starting weight generation loop 2") for data_batch in tqdm.tqdm(data_gen): class_batch = 
data_batch[1] for classe in class_batch: - classe = torch.where(classe==1)[0] + classe = torch.where(classe == 1)[0] weight_list.append(weight_dict[str(classe.item())]) res_list.append(weight_list) - + return res_list -def get_dataloader(train_path: Path, - val_path: float, - test_path: float, - transforms: List[str]=[], - num_threads: int=6, - loader_batch_size: int=32, - prohibited_classes: list=[], - debugging: bool=False, - test: bool=False, - ): +def get_dataloader( + train_path: Path, + val_path: float, + test_path: float, + transforms: List[str] = [], + num_threads: int = 6, + loader_batch_size: int = 32, + prohibited_classes: list = [], + debugging: bool = False, + test: bool = False, +): """ Returns the three dataloaders for training, validation and test. - + Inputs: train_path: pathlib.Path of the hdf5 train dataset val_path: pathlib.Path of the hdf5 val dataset @@ -103,39 +120,66 @@ def get_dataloader(train_path: Path, # get train, val and test datasets print("transforms", transforms) print("prohibited_classes", prohibited_classes) - train_dataset, valid_dataset, test_dataset = get_datasets(train_path, val_path, test_path, transforms, prohibited_classes) + train_dataset, valid_dataset, test_dataset = get_datasets( + train_path, val_path, test_path, transforms, prohibited_classes + ) # define samplers for train, val and test if debugging: test_weights = get_weights([test_dataset], loader_batch_size, num_threads) - test_sampler = t_data.WeightedRandomSampler(weights=test_weights[0], num_samples=len(test_dataset)) - test_loader = t_data.DataLoader(dataset=test_dataset, batch_size=loader_batch_size, num_workers=num_threads, sampler=test_sampler) + test_sampler = t_data.WeightedRandomSampler( + weights=test_weights[0], num_samples=len(test_dataset) + ) + test_loader = t_data.DataLoader( + dataset=test_dataset, + batch_size=loader_batch_size, + num_workers=num_threads, + sampler=test_sampler, + ) return test_loader, test_loader, test_loader elif test: - test_loader = t_data.DataLoader(dataset=test_dataset, batch_size=loader_batch_size, num_workers=num_threads) + test_loader = t_data.DataLoader( + dataset=test_dataset, batch_size=loader_batch_size, num_workers=num_threads + ) return test_loader, test_loader, test_loader train_weights = get_weights([train_dataset], loader_batch_size, num_threads)[0] - train_sampler = t_data.WeightedRandomSampler(weights=train_weights, num_samples=len(train_dataset)) + train_sampler = t_data.WeightedRandomSampler( + weights=train_weights, num_samples=len(train_dataset) + ) # create dataloader for train, val and test - train_loader = t_data.DataLoader(dataset=train_dataset, batch_size=loader_batch_size, num_workers=num_threads, sampler=train_sampler) - valid_loader = t_data.DataLoader(dataset=valid_dataset, batch_size=loader_batch_size, num_workers=num_threads) - test_loader = t_data.DataLoader(dataset=test_dataset, batch_size=loader_batch_size, num_workers=num_threads) + train_loader = t_data.DataLoader( + dataset=train_dataset, + batch_size=loader_batch_size, + num_workers=num_threads, + sampler=train_sampler, + ) + valid_loader = t_data.DataLoader( + dataset=valid_dataset, batch_size=loader_batch_size, num_workers=num_threads + ) + test_loader = t_data.DataLoader( + dataset=test_dataset, batch_size=loader_batch_size, num_workers=num_threads + ) return train_loader, valid_loader, test_loader + def identity_transform(element: Any): return element -if __name__ == '__main__': + +if __name__ == "__main__": train_path = 
Path("/home/gabriel/off/logo_classifier/datasets/train_dataset.hdf5") val_path = Path("/home/gabriel/off/logo_classifier/datasets/val_dataset.hdf5") test_path = Path("/home/gabriel/off/logo_classifier/datasets/test_dataset.hdf5") - train_loader, val_loader, test_loader = get_dataloader(train_path, val_path, test_path, ["identity_transform"], 1, 15) - + train_loader, val_loader, test_loader = get_dataloader( + train_path, val_path, test_path, ["identity_transform"], 1, 15 + ) + import numpy as np + train_classes = np.zeros(168) val_classes = np.zeros(168) test_classes = np.zeros(168) @@ -154,9 +198,10 @@ def identity_transform(element: Any): test_classes[classe] += 1 import matplotlib.pyplot as plt + plt.plot(train_classes) plt.show() plt.plot(val_classes) plt.show() plt.plot(test_classes) - plt.show() \ No newline at end of file + plt.show() diff --git a/logo-classifier/settings.py b/logo-classifier/settings.py index c9c22cb0..8f3f5226 100644 --- a/logo-classifier/settings.py +++ b/logo-classifier/settings.py @@ -1 +1 @@ -labels_path = "datasets/class_infos.jsonl" \ No newline at end of file +labels_path = "datasets/class_infos.jsonl" diff --git a/logo-classifier/to_delete.py b/logo-classifier/to_delete.py index f829952f..669c77dc 100644 --- a/logo-classifier/to_delete.py +++ b/logo-classifier/to_delete.py @@ -4,11 +4,11 @@ id = 11 -with h5py.File("datasets/test-val_dataset.hdf5", 'r') as f: +with h5py.File("datasets/test-val_dataset.hdf5", "r") as f: offset = get_offset(f) - new_ids = f['external_id'][:offset] + new_ids = f["external_id"][:offset] new_ids = np.array(new_ids) - new_classes = f['class'][:offset] + new_classes = f["class"][:offset] new_classes = np.array(new_classes) -breakpoint() \ No newline at end of file +breakpoint() diff --git a/logo-classifier/train.py b/logo-classifier/train.py index 8cec66a8..59fb2d7b 100644 --- a/logo-classifier/train.py +++ b/logo-classifier/train.py @@ -11,16 +11,25 @@ import numpy as np import settings + class LinearClassifier(nn.Module): def __init__(self, input_dim, output_dim): super().__init__() self.linear = nn.Linear(input_dim, output_dim) # Simple linear model. 
- + def forward(self, x): x = self.linear(x) return x -def train(files_dir: str, size_epoch = 10000, epochs: int=1, prohibited_classes: list=[], valid_no_class: bool=False, debugging: bool=False): + +def train( + files_dir: str, + size_epoch=10000, + epochs: int = 1, + prohibited_classes: list = [], + valid_no_class: bool = False, + debugging: bool = False, +): # define device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -30,29 +39,37 @@ def train(files_dir: str, size_epoch = 10000, epochs: int=1, prohibited_classes: train_path = pathlib.Path("datasets/train_dataset.hdf5") val_path = pathlib.Path("datasets/test-val_dataset.hdf5") test_path = pathlib.Path("datasets/test_dataset.hdf5") - train_loader, val_loader, test_loader = get_dataloader(train_path, val_path, test_path, prohibited_classes = prohibited_classes, debugging = debugging) + train_loader, val_loader, test_loader = get_dataloader( + train_path, + val_path, + test_path, + prohibited_classes=prohibited_classes, + debugging=debugging, + ) # get model - input_dim = 512 # number of features in the input data - output_dim = 168 # number of classes in the target + input_dim = 512 # number of features in the input data + output_dim = 168 # number of classes in the target model = LinearClassifier(input_dim, output_dim) model.to(device) # define loss_function criterion = torch.nn.CrossEntropyLoss() - + # define optimizer and learning rate optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) # define proba function - softmax = nn.Softmax(dim=1) # The softmax function is ONLY used to compute - # scores for the missed_logos.json file. + softmax = nn.Softmax(dim=1) # The softmax function is ONLY used to compute + # scores for the missed_logos.json file. # define lists for metrics y_train = [] y_test = [] y_pred = [] - classes_str, classes_ids = get_labels(settings.labels_path, prohibited_classes=prohibited_classes) + classes_str, classes_ids = get_labels( + settings.labels_path, prohibited_classes=prohibited_classes + ) # training loop for epoch in range(epochs): @@ -68,7 +85,9 @@ def train(files_dir: str, size_epoch = 10000, epochs: int=1, prohibited_classes: loss.backward() optimizer.step() train_loss += loss.item() - ground_truth = torch.tensor([torch.where(classe==1)[0] for classe in labels]) + ground_truth = torch.tensor( + [torch.where(classe == 1)[0] for classe in labels] + ) y_train += ground_truth if count >= size_epoch: break @@ -89,32 +108,81 @@ def train(files_dir: str, size_epoch = 10000, epochs: int=1, prohibited_classes: val_loss += loss.item() _, predicted = torch.max(outputs.data, 1) total += labels.size(0) - ground_truth = torch.tensor([torch.where(classe==1)[0] for classe in labels]) + ground_truth = torch.tensor( + [torch.where(classe == 1)[0] for classe in labels] + ) correct += (predicted == ground_truth.to(device)).sum().item() scores = softmax(outputs) - for indice in torch.where((predicted == ground_truth.to(device))==False)[0].tolist(): - missed_no_logos[str(ids[indice].item())]=[ground_truth[indice].item(), predicted[indice].item(), scores[indice][predicted[indice]].item()] + for indice in torch.where( + (predicted == ground_truth.to(device)) == False + )[0].tolist(): + missed_no_logos[str(ids[indice].item())] = [ + ground_truth[indice].item(), + predicted[indice].item(), + scores[indice][predicted[indice]].item(), + ] - if not valid_no_class: # To compute the metrics, we take out all no_class logos - # (not the ones predicted as no_class but the true no_class) - predicted = 
predicted[torch.where(ground_truth!=0)] - ground_truth = ground_truth[torch.where(ground_truth!=0)] + if ( + not valid_no_class + ): # To compute the metrics, we take out all no_class logos + # (not the ones predicted as no_class but the true no_class) + predicted = predicted[torch.where(ground_truth != 0)] + ground_truth = ground_truth[torch.where(ground_truth != 0)] y_pred += predicted.tolist() y_test += ground_truth.tolist() # Compute metrics and save model - current_macro_f1 = compute_metrics(y_test, y_pred, classes_ids, classes_str, len(train_loader), len(val_loader), correct, total, train_loss, val_loss, epoch) - + current_macro_f1 = compute_metrics( + y_test, + y_pred, + classes_ids, + classes_str, + len(train_loader), + len(val_loader), + correct, + total, + train_loss, + val_loss, + epoch, + ) + save_best_model = SaveBestModel() - save_best_model(current_macro_f1, missed_no_logos, epoch, model, device, files_dir) + save_best_model( + current_macro_f1, missed_no_logos, epoch, model, device, files_dir + ) -def compute_metrics(y_test, y_pred, classes_ids, classes_str, len_train_loader, len_val_loader, correct, total, val_loss, train_loss, epoch): + +def compute_metrics( + y_test, + y_pred, + classes_ids, + classes_str, + len_train_loader, + len_val_loader, + correct, + total, + val_loss, + train_loss, + epoch, +): # Compute all metrics. Report contains all relevant data. - report = sklearn.metrics.classification_report(y_test, y_pred, labels=classes_ids, target_names=classes_str, zero_division=0) - f1_micro = sklearn.metrics.f1_score(y_true=y_test, y_pred=y_pred, labels = classes_ids, average="micro") - f1_macro = sklearn.metrics.f1_score(y_true=y_test, y_pred=y_pred, labels = classes_ids, average="macro", zero_division=0) - f1_classes = sklearn.metrics.f1_score(y_true=y_test, y_pred=y_pred, labels = classes_ids, average=None, zero_division=0) + report = sklearn.metrics.classification_report( + y_test, y_pred, labels=classes_ids, target_names=classes_str, zero_division=0 + ) + f1_micro = sklearn.metrics.f1_score( + y_true=y_test, y_pred=y_pred, labels=classes_ids, average="micro" + ) + f1_macro = sklearn.metrics.f1_score( + y_true=y_test, + y_pred=y_pred, + labels=classes_ids, + average="macro", + zero_division=0, + ) + f1_classes = sklearn.metrics.f1_score( + y_true=y_test, y_pred=y_pred, labels=classes_ids, average=None, zero_division=0 + ) total_accuracy = 100 * correct / total training_loss = train_loss / len_train_loader validation_loss = val_loss / len_val_loader @@ -126,8 +194,8 @@ def compute_metrics(y_test, y_pred, classes_ids, classes_str, len_train_loader, "total_accuracy": total_accuracy, "training_loss": training_loss, "validation_loss": validation_loss, - } - + } + for i in range(len(f1_classes)): metrics_dict["f1 by class/f1_" + str(classes_str[i])] = f1_classes[i] wandb.log(metrics_dict) @@ -142,42 +210,46 @@ def compute_metrics(y_test, y_pred, classes_ids, classes_str, len_train_loader, plt.imshow(confusion_matrix, cmap=plt.get_cmap("Blues")) ax = plt.gca() - ax.set_xticks(np.arange(-.5, len(classes_str)-1, 1)) - ax.set_yticks(np.arange(-.5, len(classes_str)-1, 1)) + ax.set_xticks(np.arange(-0.5, len(classes_str) - 1, 1)) + ax.set_yticks(np.arange(-0.5, len(classes_str) - 1, 1)) ax.set_xticklabels( classes_str, - size='smaller', - rotation='vertical', - ) + size="smaller", + rotation="vertical", + ) ax.set_yticklabels( classes_str, - size='smaller', - rotation='horizontal', - ) - ax.grid(color='black', linestyle='-', linewidth=1) + size="smaller", + 
rotation="horizontal", + ) + ax.grid(color="black", linestyle="-", linewidth=1) plt.colorbar() plt.ylabel("Predicted") plt.xlabel("True") - wandb.log({"plot":fig}) + wandb.log({"plot": fig}) return f1_macro class SaveBestModel: """ - Class to save the best model and the wrong_predictions while training. - If the current epoch's macro f1 is better than the previous best, + Class to save the best model and the wrong_predictions while training. + If the current epoch's macro f1 is better than the previous best, then save the model state. """ - def __init__( - self, best_macro_f1=-float('inf') - ): + + def __init__(self, best_macro_f1=-float("inf")): self.best_macro_f1 = best_macro_f1 - + def __call__( - self, current_macro_f1, missed_logos, - epoch, model, device, - dir_path, batch_size = 32 + self, + current_macro_f1, + missed_logos, + epoch, + model, + device, + dir_path, + batch_size=32, ): if current_macro_f1 > self.best_macro_f1: self.best_macro_f1 = current_macro_f1 @@ -186,17 +258,18 @@ def __call__( torch.onnx.export( model, torch.zeros([batch_size, 512]).to(device), - dir_path+'/logos_classifier.onnx', + dir_path + "/logos_classifier.onnx", export_params=True, verbose=True, - input_names = ['embeddings'], - output_names = ['scores_per_classes'], - dynamic_axes = {'embeddings':[0]} - ) - add_jsonl([missed_logos], dir_path+'/missed_logos.json') + input_names=["embeddings"], + output_names=["scores_per_classes"], + dynamic_axes={"embeddings": [0]}, + ) + add_jsonl([missed_logos], dir_path + "/missed_logos.json") + -if __name__ == '__main__': - ''' +if __name__ == "__main__": + """ Script used to train the logos classifier model. Give a name to the wandb project. @@ -205,15 +278,21 @@ def __call__( * size_epoch: amount of batches corresponding to one epoch of the training. * epochs: amount of epochs run for training * prohibited_classes: list of the ids of classes you want to prohibit for training, validation and test - * valid_no_class: True if you want no_class logos to be taking into account for the computation + * valid_no_class: True if you want no_class logos to be taking into account for the computation of metrics during validation step. False else. * debugging: True if you want to run the script by using test dataset as train and val, as it is shorter than the former. False for usual use. Run the script ! 
- ''' - + """ run = wandb.init(project="with_no_class-logos-classifier") files_dir = run.dir - train(files_dir=files_dir, size_epoch=1680, epochs=200, prohibited_classes=[], valid_no_class=True, debugging=False) + train( + files_dir=files_dir, + size_epoch=1680, + epochs=200, + prohibited_classes=[], + valid_no_class=True, + debugging=False, + ) run.finish() diff --git a/logo-classifier/utils.py b/logo-classifier/utils.py index f9b90abc..9991f43b 100644 --- a/logo-classifier/utils.py +++ b/logo-classifier/utils.py @@ -5,45 +5,53 @@ import json import typing + def get_offset(f: h5py.File) -> int: external_id_dset = f["external_id"] array = external_id_dset[:] non_zero_indexes = np.flatnonzero(array) return int(non_zero_indexes[-1]) + 1 + def get_config(config_file: str): - with open('config.yaml') as f: + with open("config.yaml") as f: return yaml.load(f, Loader=SafeLoader) + def get_labels(labels_path: str, prohibited_classes: list): ids = [] str = [] - with open(labels_path, 'rb') as f: + with open(labels_path, "rb") as f: for row in f: dicti = json.loads(row) - if dicti["id"] in prohibited_classes : continue + if dicti["id"] in prohibited_classes: + continue ids.append(dicti["id"]) str.append(dicti["class"]) return str, ids + def add_jsonl(list_dict_to_append: list, json_file: str): for dict in list_dict_to_append: - with open(json_file, 'w') as f: + with open(json_file, "w") as f: json.dump(dict, f) - f.write('\n') + f.write("\n") + def read_json(json_file: str): - with open(json_file, 'rb') as f: + with open(json_file, "rb") as f: for row in f: yield json.loads(row) + def get_str_labels(query_ids: list, labels_ids: np.array, labels_str: np.array): res = [] for id in query_ids: res.append(labels_str[np.where(labels_ids == id)][0]) return res - + + if __name__ == "__main__": data_gen = read_json("test.jsonl") for dict in data_gen: - print(dict) \ No newline at end of file + print(dict) diff --git a/logo-classifier/visualize_logos.py b/logo-classifier/visualize_logos.py index 36c0bfa8..9259afde 100644 --- a/logo-classifier/visualize_logos.py +++ b/logo-classifier/visualize_logos.py @@ -9,10 +9,12 @@ from requests.exceptions import SSLError, Timeout import PIL from io import BytesIO -import os +import os + + +class ImagePreloader(t_data.Dataset): + """Class for main Dataset Classes""" -class ImagePreloader(t_data.Dataset): - '''Class for main Dataset Classes''' def __init__(self, classes: list, json_dataset: str, missed_logos: dict): self.ids = [] self.source = [] @@ -23,31 +25,34 @@ def __init__(self, classes: list, json_dataset: str, missed_logos: dict): set_ids = set(missed_logos.keys()) data_gen = read_json(json_dataset) for dict in tqdm.tqdm(data_gen): - id = str(dict['logo_id']) - if id in set_ids and (classes == [] or missed_logos[id]['prediction'] in classes): + id = str(dict["logo_id"]) + if id in set_ids and ( + classes == [] or missed_logos[id]["prediction"] in classes + ): self.ids.append(id) - self.source.append(dict['source_img']) - self.predicted.append(missed_logos[id]['prediction']) - self.true.append(missed_logos[id]['truth']) - self.bounding_box.append(dict['bounding_box']) - self.score.append(missed_logos[id]['score']) + self.source.append(dict["source_img"]) + self.predicted.append(missed_logos[id]["prediction"]) + self.true.append(missed_logos[id]["truth"]) + self.bounding_box.append(dict["bounding_box"]) + self.score.append(missed_logos[id]["score"]) def __len__(self): return len(self.ids) - + def __getitem__(self, idx): return ( - self.ids[idx], - 
self.predicted[idx], - self.true[idx], + self.ids[idx], + self.predicted[idx], + self.true[idx], generate_images( - self.source[idx], + self.source[idx], self.bounding_box[idx], ), - self.score[idx] + self.score[idx], ) + def get_classes(input_json_file): classes_str, classes_ids = get_labels(settings.labels_path, []) classes_str = np.array(classes_str) @@ -59,23 +64,27 @@ def get_classes(input_json_file): for logos_id in missed_logos.keys(): truth, prediction, score = missed_logos[logos_id] - truth, prediction = get_str_labels([truth, prediction], classes_ids, classes_str) + truth, prediction = get_str_labels( + [truth, prediction], classes_ids, classes_str + ) + + res[logos_id] = {"truth": truth, "prediction": prediction, "score": score} - res[logos_id] = {'truth': truth, 'prediction': prediction, 'score': score} - return res + def custom_collate_fn(batch): return batch - + + def generate_images(source_img: str, bounding_box: list): base_url = "https://images.openfoodfacts.org/images/products" image_url = base_url + source_img - try: + try: r = requests.get(image_url) except (RequestConnectionError, SSLError, Timeout): return None - if r.status_code == 404: + if r.status_code == 404: return None try: image = Image.open(BytesIO(r.content)) @@ -83,18 +92,21 @@ def generate_images(source_img: str, bounding_box: list): print(f"This one has a problem : source_img") return None try: - assert np.shape(image)[0]>0 + assert np.shape(image)[0] > 0 except (IndexError, AssertionError): - print(f"This one has a shape problem : source_img and shape = {np.shape(image)}") + print( + f"This one has a shape problem : source_img and shape = {np.shape(image)}" + ) return None y_min, x_min, y_max, x_max = bounding_box height, width = np.shape(image)[0], np.shape(image)[1] - cropped = image.crop((width*x_min, height*y_min, width*x_max, height*y_max)) - image = cropped.resize((224,224)) - if image.mode == 'CMYK': - image = image.convert('RGB') + cropped = image.crop((width * x_min, height * y_min, width * x_max, height * y_max)) + image = cropped.resize((224, 224)) + if image.mode == "CMYK": + image = image.convert("RGB") return image + def save_image(id, predicted, true, image, score): base_dict = "missed_logos/" class_predicted_dict = predicted + "/" @@ -106,27 +118,31 @@ def save_image(id, predicted, true, image, score): os.mkdir(base_dict + class_predicted_dict + true_class_dict) os.mkdir(base_dict + class_predicted_dict + true_class_dict + "true/") os.mkdir(base_dict + class_predicted_dict + true_class_dict + "false/") - if not os.path.isfile(base_dict + class_predicted_dict + true_class_dict + image_title): + if not os.path.isfile( + base_dict + class_predicted_dict + true_class_dict + image_title + ): try: image.save(base_dict + class_predicted_dict + true_class_dict + image_title) except: breakpoint() -if __name__ == '__main__': - ''' + +if __name__ == "__main__": + """ Once the training is run, if you want to check the logos where the model was wrong, run this script. - *missed_logos_file should contain the path of the file created at the same time as the onnx + *missed_logos_file should contain the path of the file created at the same time as the onnx model saving during training. *logos_infos_file should contain the path of a jsonl file where each line is a dictionnary with as this one : {"class": "no_class", "id": 0, ....}. 
- ''' + """ missed_logos_file = "missed_logos.json" logos_infos_file = "datasets/jsonl_dataset.jsonl" missed_logos = get_classes(missed_logos_file) dataset = ImagePreloader([], logos_infos_file, missed_logos) - dataloader = t_data.DataLoader(dataset, batch_size=32, num_workers=4, collate_fn=custom_collate_fn) + dataloader = t_data.DataLoader( + dataset, batch_size=32, num_workers=4, collate_fn=custom_collate_fn + ) for batch in tqdm.tqdm(dataloader): for id, predicted, true, image, score in batch: save_image(id, predicted, true, image, score) - diff --git a/object_detection/crop_detection/cli/inference_yolo_tflite.py b/object_detection/crop_detection/cli/inference_yolo_tflite.py index 8dd6ade4..8a610364 100644 --- a/object_detection/crop_detection/cli/inference_yolo_tflite.py +++ b/object_detection/crop_detection/cli/inference_yolo_tflite.py @@ -34,10 +34,7 @@ def parse(): help="Path to the image to be processed.", ) parser.add_argument( - "--model-path", - type=str, - default=MODEL_PATH, - help="Path to the .tflite model." + "--model-path", type=str, default=MODEL_PATH, help="Path to the .tflite model." ) parser.add_argument( "--threshold", @@ -46,10 +43,10 @@ def parse(): help="Detection score threshold.", ) parser.add_argument( - "--nms-threshold", - type=float, + "--nms-threshold", + type=float, default=NMS_THRESHOLD, - help="Non-Maximum Suppression threshold." + help="Non-Maximum Suppression threshold.", ) parser.add_argument( "--debug", @@ -58,10 +55,7 @@ def parse(): help="Set debug mode.", ) parser.add_argument( - "--save-path", - type=str, - required=False, - help="Path to save the cropped image." + "--save-path", type=str, required=False, help="Path to save the cropped image." ) return parser.parse_args() diff --git a/object_detection/tensorflow_object_api/object_detection.py b/object_detection/tensorflow_object_api/object_detection.py index 4c795375..e6998ecd 100644 --- a/object_detection/tensorflow_object_api/object_detection.py +++ b/object_detection/tensorflow_object_api/object_detection.py @@ -107,8 +107,8 @@ def __init__(self, graph: tf.Graph, label_map): self.categories = label_map_util.convert_label_map_to_categories( label_map, max_num_classes=1000 ) - self.category_index: CategoryIndex = ( - label_map_util.create_category_index(self.categories) + self.category_index: CategoryIndex = label_map_util.create_category_index( + self.categories ) @classmethod @@ -282,7 +282,7 @@ def iter_images_batch( current_dim = None batch: List[Tuple[str, str]] = [] - for (width, height, barcode, image_id) in iter_image_dimensions(file_path): + for width, height, barcode, image_id in iter_image_dimensions(file_path): key = (barcode, image_id) if key in seen_set: @@ -362,7 +362,8 @@ def parse_args(): parser.add_argument("data_path", type=pathlib.Path) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument( - "--invalid-path", type=pathlib.Path, + "--invalid-path", + type=pathlib.Path, ) return parser.parse_args() diff --git a/object_detection/tensorflow_object_api/string_int_label_map_pb2.py b/object_detection/tensorflow_object_api/string_int_label_map_pb2.py index 7693844d..f338e9a2 100644 --- a/object_detection/tensorflow_object_api/string_int_label_map_pb2.py +++ b/object_detection/tensorflow_object_api/string_int_label_map_pb2.py @@ -144,7 +144,7 @@ (_message.Message,), dict( DESCRIPTOR=_STRINGINTLABELMAPITEM, - __module__="robotoff.ml.object_detection.utils.string_int_label_map_pb2" + __module__="robotoff.ml.object_detection.utils.string_int_label_map_pb2", # 
@@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem) ), ) @@ -155,7 +155,7 @@ (_message.Message,), dict( DESCRIPTOR=_STRINGINTLABELMAP, - __module__="robotoff.ml.object_detection.utils.string_int_label_map_pb2" + __module__="robotoff.ml.object_detection.utils.string_int_label_map_pb2", # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap) ), ) diff --git a/ocr_cleaning/multiproc_ocr_cleaning.py b/ocr_cleaning/multiproc_ocr_cleaning.py index 2dad89eb..b083f6c0 100644 --- a/ocr_cleaning/multiproc_ocr_cleaning.py +++ b/ocr_cleaning/multiproc_ocr_cleaning.py @@ -22,6 +22,7 @@ > Found 795609 fr texts out of 2481313 lines (32.1%) > Took 0.6s to write 795609 lines to output (1261237.4 lines/s). """ + import itertools import json import multiprocessing diff --git a/spellcheck/old/SessionState.py b/spellcheck/old/SessionState.py index 77da0737..c76db419 100644 --- a/spellcheck/old/SessionState.py +++ b/spellcheck/old/SessionState.py @@ -18,6 +18,7 @@ >>> session_state.user_name 'Mary' """ + import streamlit.ReportThread as ReportThread from streamlit.server.Server import Server diff --git a/spellcheck/old/evaluation/metrics.py b/spellcheck/old/evaluation/metrics.py index f906650e..12ec817e 100644 --- a/spellcheck/old/evaluation/metrics.py +++ b/spellcheck/old/evaluation/metrics.py @@ -156,16 +156,21 @@ def per_item_ingredients_metrics(original: str, correct: str, prediction: str): predicted_count = len(predicted_ingredients) correct_predicted_count = _matching_tokens_count( - correct_ingredients, predicted_ingredients, + correct_ingredients, + predicted_ingredients, ) original_correct_count = _matching_tokens_count( - original_ingredients, correct_ingredients, + original_ingredients, + correct_ingredients, ) original_predicted_count = _matching_tokens_count( - original_ingredients, predicted_ingredients, + original_ingredients, + predicted_ingredients, ) original_correct_predicted_count = _matching_tokens_count( - original_ingredients, predicted_ingredients, correct_ingredients, + original_ingredients, + predicted_ingredients, + correct_ingredients, ) precision_num = correct_predicted_count - original_correct_predicted_count diff --git a/spellcheck/old/ingredients.py b/spellcheck/old/ingredients.py index e6ae6fde..69eaf443 100644 --- a/spellcheck/old/ingredients.py +++ b/spellcheck/old/ingredients.py @@ -6,6 +6,7 @@ /!\ /!\ /!\ """ + import re from dataclasses import dataclass, field from typing import Iterable, List, Set, Tuple, Dict diff --git a/spellcheck/old/label.py b/spellcheck/old/label.py index 33bffbd0..996234a5 100644 --- a/spellcheck/old/label.py +++ b/spellcheck/old/label.py @@ -50,7 +50,10 @@ def query_items(limit): return [item for item in query] -session = SessionState.get(cursor=0, corrected_items=load_dataset(FR_TEST_SET_PATH),) +session = SessionState.get( + cursor=0, + corrected_items=load_dataset(FR_TEST_SET_PATH), +) # Select an item to label/correct items_to_label = query_items(limit=10000) diff --git a/spellcheck/old/mongo.py b/spellcheck/old/mongo.py index ce7ec147..1e377ac1 100644 --- a/spellcheck/old/mongo.py +++ b/spellcheck/old/mongo.py @@ -1,4 +1,7 @@ from pymongo import MongoClient # Connect to local Mongo DB -products = MongoClient(host="localhost", port=27017,).off.products +products = MongoClient( + host="localhost", + port=27017, +).off.products diff --git a/spellcheck/scripts/argilla/benchmark/add_records.py b/spellcheck/scripts/argilla/benchmark/add_records.py index 98001031..bcb72423 100644 --- 
a/spellcheck/scripts/argilla/benchmark/add_records.py +++ b/spellcheck/scripts/argilla/benchmark/add_records.py @@ -1,7 +1,7 @@ from typing import Iterable import argilla as rg -import pandas as pd +import pandas as pd from spellcheck.utils import get_logger, get_repo_dir, show_diff @@ -9,7 +9,10 @@ LOGGER = get_logger("INFO") REPO_DIR = get_repo_dir() -DATA_PATH = REPO_DIR / "data/benchmark/additional_products/synthetically_corrected_products.parquet" +DATA_PATH = ( + REPO_DIR + / "data/benchmark/additional_products/synthetically_corrected_products.parquet" +) ARGILLA_DATASET_NAME = "benchmark_v2" ARGILLA_WORKSPACE_NAME = "spellcheck" @@ -18,18 +21,17 @@ def main(): - + df = pd.read_parquet(DATA_PATH) LOGGER.info(f"Features: {df.columns}") dataset = rg.FeedbackDataset.from_argilla( - name=ARGILLA_DATASET_NAME, - workspace=ARGILLA_WORKSPACE_NAME + name=ARGILLA_DATASET_NAME, workspace=ARGILLA_WORKSPACE_NAME ) records = prepare_records( originals=df["ingredients_text"].tolist(), references=df["correction"].tolist(), codes=df["code"].tolist(), - langs=df["lang"].tolist() + langs=df["lang"].tolist(), ) dataset.add_records(records) @@ -38,7 +40,7 @@ def prepare_records( originals: Iterable[str], references: Iterable[str], langs: Iterable[str], - codes: Iterable[int] + codes: Iterable[int], ) -> Iterable[rg.FeedbackRecord]: """Prepare records for Argilla. @@ -55,21 +57,20 @@ def prepare_records( rg.FeedbackRecord( fields={ "original": original, - "url": PRODUCT_URL.format(code) if code else None + "url": PRODUCT_URL.format(code) if code else None, }, suggestions=[ rg.SuggestionSchema( question_name="reference", - value=show_diff(original_text=original, corrected_text=reference) + value=show_diff(original_text=original, corrected_text=reference), ) ], - metadata={ - "lang": lang - } - ) for original, reference, code, lang in zip(originals, references, codes, langs) + metadata={"lang": lang}, + ) + for original, reference, code, lang in zip(originals, references, codes, langs) ] return records -if __name__ =="__main__": +if __name__ == "__main__": main() diff --git a/spellcheck/scripts/argilla/benchmark/deploy_benchmark.py b/spellcheck/scripts/argilla/benchmark/deploy_benchmark.py index ebe1ba24..cf17a9f2 100644 --- a/spellcheck/scripts/argilla/benchmark/deploy_benchmark.py +++ b/spellcheck/scripts/argilla/benchmark/deploy_benchmark.py @@ -11,6 +11,5 @@ if __name__ == "__main__": BenchmarkArgilla.from_parquet(path=BENCHMARK_PATH).deploy( - dataset_name=ARGILLA_DATASET_NAME, - workspace_name=ARGILLA_WORKSPACE_NAME + dataset_name=ARGILLA_DATASET_NAME, workspace_name=ARGILLA_WORKSPACE_NAME ) diff --git a/spellcheck/scripts/argilla/benchmark/extract_benchmark.py b/spellcheck/scripts/argilla/benchmark/extract_benchmark.py index a9ce40b5..34c66b6e 100644 --- a/spellcheck/scripts/argilla/benchmark/extract_benchmark.py +++ b/spellcheck/scripts/argilla/benchmark/extract_benchmark.py @@ -26,18 +26,18 @@ def main(): name=ARGILLA_DATASET_NAME, workspace=ARGILLA_WORKSPACE_NAME, postprocess_map_fn=postprocessing_map_fn, - postprocess_filter_fn=postprocessing_filter_fn + postprocess_filter_fn=postprocessing_filter_fn, ) dataset.to_parquet(BENCHMARK_PATH) def extract_dataset( - name: str, - workspace: str, - postprocess_map_fn: Callable = None, - postprocess_filter_fn: Callable = None - ) -> Dataset: - """Extract the annotated dataset from the deployed Argilla. 
+ name: str, + workspace: str, + postprocess_map_fn: Callable = None, + postprocess_filter_fn: Callable = None, +) -> Dataset: + """Extract the annotated dataset from the deployed Argilla. Args: name (str): Argilla dataset @@ -53,15 +53,17 @@ def extract_dataset( LOGGER.info(f"Dataset: {hf_dataset}") if postprocess_map_fn: return postprocess_dataset( - dataset=hf_dataset, + dataset=hf_dataset, map_function=postprocess_map_fn, - filter_fn=postprocess_filter_fn + filter_fn=postprocess_filter_fn, ) return hf_dataset -def postprocess_dataset(dataset: Dataset, map_function: Callable, filter_fn: Callable) -> Dataset: - """Post-processing the dataset. +def postprocess_dataset( + dataset: Dataset, map_function: Callable, filter_fn: Callable +) -> Dataset: + """Post-processing the dataset. Args: dataset (Dataset): Exported dataset from Argilla @@ -80,7 +82,7 @@ def postprocessing_map_fn(element: Mapping) -> Mapping: """Mapping unction applied to the dataset. Args: - element (Mapping): + element (Mapping): One row of the extracted dataset before processing: ``` @@ -88,8 +90,8 @@ def postprocessing_map_fn(element: Mapping) -> Mapping: 'original': 'Ananas, Ananassaft, Säuerungs - mittel: Citronensäure' 'reference': [ { - 'user_id': 'dfb71753-1187-45e1-8006-629bef2b49e0', - 'value': 'Ananas, Ananassaft, Säuerungsmittel: Citronensäure', + 'user_id': 'dfb71753-1187-45e1-8006-629bef2b49e0', + 'value': 'Ananas, Ananassaft, Säuerungsmittel: Citronensäure', 'status': 'submitted' } ] @@ -101,11 +103,15 @@ def postprocessing_map_fn(element: Mapping) -> Mapping: 'external_id': None 'metadata': '{"lang": "de", "data_origin": "labeled_data"}' ``` - + Returns: Mapping: Post-processed element """ - reference = element["reference"][0]["value"] if element["reference"] else element["reference-suggestion"] + reference = ( + element["reference"][0]["value"] + if element["reference"] + else element["reference-suggestion"] + ) postprocessed_reference = remove_markdown(reference) lang = json.loads(element["metadata"]).get("lang") data_origin = json.loads(element["metadata"]).get("data_origin") @@ -114,24 +120,28 @@ def postprocessing_map_fn(element: Mapping) -> Mapping: "reference": postprocessed_reference, "lang": lang, "data_origin": data_origin, - "is_truncated": 0 if not element["is_truncated"] or element["is_truncated"][0]["value"] == "NO" else 1 + "is_truncated": ( + 0 + if not element["is_truncated"] + or element["is_truncated"][0]["value"] == "NO" + else 1 + ), } - + def postprocessing_filter_fn( - element: Mapping, - status: Literal["submitted", "discarded", "draft"] = "submitted" - ) -> bool: + element: Mapping, status: Literal["submitted", "discarded", "draft"] = "submitted" +) -> bool: """Filter dataset depending on annotation status. 
Args: - element (Mapping): - One row of the extracted dataset before processing + element (Mapping): + One row of the extracted dataset before processing (One would notice that this data was 'discarded' by the annotator) ``` 'url': 'https://world.openfoodfacts.org/product/5942262001416' - 'original': 'water:snow' + 'original': 'water:snow' 'reference': [{'user_id': 'dfb71753-1187-45e1-8006-629bef2b49e0', 'value': 'water:snow', 'status': 'discarded'}] 'reference-suggestion': 'water:snow' 'reference-suggestion-metadata': {'type': None, 'score': None, 'agent': None} @@ -153,9 +163,8 @@ def postprocessing_filter_fn( def remove_markdown( - text: str, - deleted_element: str = ArgillaConfig.deleted_element -)-> str: + text: str, deleted_element: str = ArgillaConfig.deleted_element +) -> str: """Markdowns were added to the text in Argilla to highlight the difference with the original text. They are removed during the dataset extraction. @@ -165,9 +174,13 @@ def remove_markdown( Returns: str: Post-processed text - """ - text = re.sub("]*)?>" + deleted_element + "<\/mark>", "", text) # # - # if an element was deleted - text = re.sub("<\/?mark(?:\s\w+[^>]*)?>", "", text) # - - + """ + text = re.sub( + "]*)?>" + deleted_element + "<\/mark>", "", text + ) # # - # if an element was deleted + text = re.sub( + "<\/?mark(?:\s\w+[^>]*)?>", "", text + ) # - - return text diff --git a/spellcheck/scripts/argilla/benchmark/update_benchmark.py b/spellcheck/scripts/argilla/benchmark/update_benchmark.py index 6c04de55..4644a7e9 100644 --- a/spellcheck/scripts/argilla/benchmark/update_benchmark.py +++ b/spellcheck/scripts/argilla/benchmark/update_benchmark.py @@ -10,8 +10,8 @@ logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) LOGGER = logging.getLogger(__name__) @@ -26,27 +26,24 @@ def main(): push_dataset_updates( previous_name=ARGILLA_DATASET_NAME, workspace=ARGILLA_WORKSPACE_NAME, - new_name=ARGILLA_TEST_DATASET_NAME + new_name=ARGILLA_TEST_DATASET_NAME, ) def update_dataset(name: str, workspace: str) -> None: - """Update exisiting Argilla dataset. + """Update exisiting Argilla dataset. Only suggestions, responses and metadata can be modified. ATTENTION: test the update by creating a new dataset first to ensure you don't delete existing annotations. - Args: + Args: name (str): Dataset name workspace (str): Workspace name """ # Extract previous annotation - dataset = rg.FeedbackDataset.from_argilla( - name=name, - workspace=workspace - ) + dataset = rg.FeedbackDataset.from_argilla(name=name, workspace=workspace) updated_records = update_records(dataset.records) dataset.update_records(updated_records) - + def update_records(records: Iterable[rg.FeedbackRecord]) -> Iterable[rg.FeedbackRecord]: """Update records. 
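# A hedged note on the <mark> round-trip used across these Argilla scripts:
# show_diff() (imported from spellcheck.utils) renders a correction with the
# edited spans highlighted so annotators can see what changed, and
# remove_markdown() in extract_benchmark.py strips that markup again on export.
# The snippet below only demonstrates the stripping regex on a made-up list of
# ingredients; the exact markup emitted by show_diff() and the
# ArgillaConfig.deleted_element token are not shown in this patch.
import re

assert (
    re.sub("<\/?mark(?:\s\w+[^>]*)?>", "", "water, <mark>sugar</mark>, salt")
    == "water, sugar, salt"
)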
@@ -61,32 +58,34 @@ def update_records(records: Iterable[rg.FeedbackRecord]) -> Iterable[rg.Feedback for record in records: original = record.fields.get("original") suggestion = record.suggestions[0].value if record.suggestions else None - response = record.responses[0].values["reference"].value if record.responses else None + response = ( + record.responses[0].values["reference"].value if record.responses else None + ) if suggestion: record.suggestions = [ { "question_name": "reference", "value": show_diff(original, suggestion), - "agent": "gpt-3.5" + "agent": "gpt-3.5", } ] if response: record.responses = [ - { - "values":{ - "reference":{ - "value": show_diff(original, response), - } - }, - "inserted_at": datetime.now(), - "updated_at": datetime.now(), - "status": "submitted" - } - ] + { + "values": { + "reference": { + "value": show_diff(original, response), + } + }, + "inserted_at": datetime.now(), + "updated_at": datetime.now(), + "status": "submitted", + } + ] record.metadata = { - "lang": record.metadata.get("lang"), - "data_origin": _update_metadata(record.metadata["data_origin"]) - } + "lang": record.metadata.get("lang"), + "data_origin": _update_metadata(record.metadata["data_origin"]), + } modified_record.append(record) return modified_record @@ -110,10 +109,7 @@ def push_dataset_updates(previous_name: str, workspace: str, new_name: str) -> N workspace (str): Workspace name new_name (str): New dataset name """ - rg.init( - api_url=os.getenv("ARGILLA_API_URL"), - api_key=os.getenv("ARGILLA_API_KEY") - ) + rg.init(api_url=os.getenv("ARGILLA_API_URL"), api_key=os.getenv("ARGILLA_API_KEY")) dataset = rg.FeedbackDataset( fields=[ @@ -121,30 +117,30 @@ def push_dataset_updates(previous_name: str, workspace: str, new_name: str) -> N rg.TextField(name="original", title="Original", use_markdown=True), ], questions=[ - rg.TextQuestion(name="reference", title="Correct the prediction.", use_markdown=True), + rg.TextQuestion( + name="reference", title="Correct the prediction.", use_markdown=True + ), rg.LabelQuestion( name="is_truncated", title="Is the list of ingredients truncated?", - labels=["YES","NO"], - required=False - ) + labels=["YES", "NO"], + required=False, + ), ], metadata_properties=[ rg.TermsMetadataProperty(name="lang", title="Language"), - rg.TermsMetadataProperty(name="data_origin", title="Origin") + rg.TermsMetadataProperty(name="data_origin", title="Origin"), ], ) # Load previous dataset previous_dataset = rg.FeedbackDataset.from_argilla( - name=previous_name, - workspace=workspace + name=previous_name, workspace=workspace ) updated_records = update_records(previous_dataset.records) dataset.add_records(updated_records) dataset.push_to_argilla(name=new_name, workspace=workspace) - if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/spellcheck/scripts/argilla/dataset/deploy_training.py b/spellcheck/scripts/argilla/dataset/deploy_training.py index 413bb7e7..ece8dfb3 100644 --- a/spellcheck/scripts/argilla/dataset/deploy_training.py +++ b/spellcheck/scripts/argilla/dataset/deploy_training.py @@ -16,14 +16,13 @@ if __name__ == "__main__": - + TrainingDataArgilla.from_dataset( hf_repo="openfoodfacts/spellcheck-dataset", split="train+test", original_feature="text", reference_feature="label", ).deploy( - dataset_name=ARGILLA_DATASET_NAME, + dataset_name=ARGILLA_DATASET_NAME, workspace_name=ARGILLA_WORKSPACE_NAME, ) - \ No newline at end of file diff --git a/spellcheck/scripts/batch/main.py b/spellcheck/scripts/batch/main.py index 
81e6b12d..65bac508 100644 --- a/spellcheck/scripts/batch/main.py +++ b/spellcheck/scripts/batch/main.py @@ -18,18 +18,53 @@ def parse() -> argparse.Namespace: - """Parse command line arguments. - """ + """Parse command line arguments.""" parser = argparse.ArgumentParser(description="Spellcheck module.") - parser.add_argument("--data_bucket", type=str, default="robotoff-spellcheck", help="Bucket name.") - parser.add_argument("--pre_data_suffix", type=str, default="data/test_data.parquet", help="Dataset suffix containing the data to be processed.") - parser.add_argument("--post_data_suffix", type=str, default="data/test_processed_data.parquet", help="Dataset suffix containing the processed data.") - parser.add_argument("--model_path", default="openfoodfacts/spellcheck-mistral-7b", type=str, help="HF model path.") - parser.add_argument("--max_model_len", default=1024, type=int, help="Maximum model context length. A lower max context length reduces the memory footprint and accelerate the inference.") - parser.add_argument("--temperature", default=0, type=float, help="Sampling temperature.") - parser.add_argument("--max_tokens", default=1024, type=int, help="Maximum number of tokens to generate.") - parser.add_argument("--quantization", default="fp8", type=str, help="Quantization type.") - parser.add_argument("--dtype", default="auto", type=str, help="Model weights precision. Default corresponds to the modle config (float16 here)") + parser.add_argument( + "--data_bucket", type=str, default="robotoff-spellcheck", help="Bucket name." + ) + parser.add_argument( + "--pre_data_suffix", + type=str, + default="data/test_data.parquet", + help="Dataset suffix containing the data to be processed.", + ) + parser.add_argument( + "--post_data_suffix", + type=str, + default="data/test_processed_data.parquet", + help="Dataset suffix containing the processed data.", + ) + parser.add_argument( + "--model_path", + default="openfoodfacts/spellcheck-mistral-7b", + type=str, + help="HF model path.", + ) + parser.add_argument( + "--max_model_len", + default=1024, + type=int, + help="Maximum model context length. A lower max context length reduces the memory footprint and accelerate the inference.", + ) + parser.add_argument( + "--temperature", default=0, type=float, help="Sampling temperature." + ) + parser.add_argument( + "--max_tokens", + default=1024, + type=int, + help="Maximum number of tokens to generate.", + ) + parser.add_argument( + "--quantization", default="fp8", type=str, help="Quantization type." + ) + parser.add_argument( + "--dtype", + default="auto", + type=str, + help="Model weights precision. Default corresponds to the modle config (float16 here)", + ) return parser.parse_args() @@ -39,7 +74,7 @@ def main(): Original lists of ingredients are stored in a gs bucket before being loaded then processed by the model. The corrected lists of ingredients are then stored back in gs. - We use vLLM to process the batch optimaly. The model is loaded from the Open Food Facts Hugging Face model repository. + We use vLLM to process the batch optimaly. The model is loaded from the Open Food Facts Hugging Face model repository. """ LOGGER.info("Starting batch processing job.") args = parse() @@ -48,35 +83,39 @@ def main(): data = load_gcs(bucket_name=args.data_bucket, suffix=args.pre_data_suffix) LOGGER.info(f"Feature in uploaded data: {data.columns}") if not all(feature in data.columns for feature in FEATURES_VALIDATION): - raise ValueError(f"Data should contain the following features: {FEATURES_VALIDATION}. 
Current features: {data.columns}") + raise ValueError( + f"Data should contain the following features: {FEATURES_VALIDATION}. Current features: {data.columns}" + ) instructions = [prepare_instruction(text) for text in data["text"]] llm = LLM( - model=args.model_path, - max_model_len=args.max_model_len, + model=args.model_path, + max_model_len=args.max_model_len, dtype=args.dtype, quantization=args.quantization, ) sampling_params = SamplingParams( - temperature=args.temperature, - max_tokens=args.max_tokens + temperature=args.temperature, max_tokens=args.max_tokens ) - LOGGER.info(f"Starting batch inference:\n {llm}.\n\nSampling parameters: {sampling_params}") - data["correction"] = batch_inference(instructions, llm=llm, sampling_params=sampling_params) + LOGGER.info( + f"Starting batch inference:\n {llm}.\n\nSampling parameters: {sampling_params}" + ) + data["correction"] = batch_inference( + instructions, llm=llm, sampling_params=sampling_params + ) LOGGER.info(f"Uploading data to GCS: {args.data_bucket}/{args.post_data_suffix}") # Save DataFrame as Parquet to a temporary file - with tempfile.NamedTemporaryFile(delete=True, suffix='.parquet') as temp_file: + with tempfile.NamedTemporaryFile(delete=True, suffix=".parquet") as temp_file: data.to_parquet(temp_file.name) temp_file_name = temp_file.name upload_gcs( - temp_file_name, - bucket_name=args.data_bucket, - suffix=args.post_data_suffix + temp_file_name, bucket_name=args.data_bucket, suffix=args.post_data_suffix ) LOGGER.info("Batch processing job completed.") + def prepare_instruction(text: str) -> str: """Prepare instruction prompt for fine-tuning and inference. @@ -87,18 +126,14 @@ def prepare_instruction(text: str) -> str: str: Instruction. """ instruction = ( - "###Correct the list of ingredients:\n" - + text - + "\n\n###Correction:\n" + "###Correct the list of ingredients:\n" + text + "\n\n###Correction:\n" ) return instruction def batch_inference( - texts: List[str], - llm: LLM, - sampling_params: SamplingParams - ) -> List[str]: + texts: List[str], llm: LLM, sampling_params: SamplingParams +) -> List[str]: """Process batch of texts with vLLM. Args: @@ -109,7 +144,10 @@ def batch_inference( Returns: List[str]: Processed batch of texts """ - outputs = llm.generate(texts, sampling_params,) + outputs = llm.generate( + texts, + sampling_params, + ) corrections = [output.outputs[0].text for output in outputs] return corrections @@ -118,7 +156,7 @@ def load_gcs(bucket_name: str, suffix: str) -> pd.DataFrame: """Load data from Google Cloud Storage bucket. 
Args: - bucket_name (str): + bucket_name (str): suffix (str): Path inside the bucket Returns: @@ -145,5 +183,6 @@ def upload_gcs(file_path: str, bucket_name: str, suffix: str) -> None: blob = bucket.blob(suffix) blob.upload_from_filename(filename=file_path) + if __name__ == "__main__": main() diff --git a/spellcheck/scripts/benchmark/create_benchmark.py b/spellcheck/scripts/benchmark/create_benchmark.py index 524f2d41..9b3e0c68 100644 --- a/spellcheck/scripts/benchmark/create_benchmark.py +++ b/spellcheck/scripts/benchmark/create_benchmark.py @@ -13,8 +13,8 @@ LOGGER = logging.getLogger(__name__) logging.basicConfig( - level=logging.getLevelName("INFO"), - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + level=logging.getLevelName("INFO"), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) URL = "https://world.openfoodfacts.org/data-quality-warning/en:ingredients-50-percent-unknown.json?page_size={}" @@ -145,9 +145,9 @@ def prepare_data_from_old_fr( spellcheck = Spellcheck( model=OpenAIChatCompletion( prompt_template=Prompt.spellcheck_prompt_template, - system_prompt=SystemPrompt.spellcheck_system_prompt + system_prompt=SystemPrompt.spellcheck_system_prompt, ) - ) + ) prepare_benchmark( benchmark_version=BENCHMARK_VERSION, save_path=BENCHMARK_PATH, diff --git a/spellcheck/scripts/benchmark/create_test_benchmark.py b/spellcheck/scripts/benchmark/create_test_benchmark.py index fb3ccc8a..2147c482 100644 --- a/spellcheck/scripts/benchmark/create_test_benchmark.py +++ b/spellcheck/scripts/benchmark/create_test_benchmark.py @@ -28,24 +28,21 @@ def main(): model=GeminiModel( prompt_template=Prompt.claude_spellcheck_prompt_template, system_prompt=SystemPrompt.spellcheck_system_prompt, - model_name=MODEL_NAME + model_name=MODEL_NAME, ) ) prepare_test_benchmark( labeled_data_path=LABELED_DATA_PATH, spellcheck=spellcheck, save_path=BENCHMARK_PATH, - wait=WAIT + wait=WAIT, ) def prepare_test_benchmark( - labeled_data_path: Path, - spellcheck: Spellcheck, - save_path: Path, - wait: int = 0 + labeled_data_path: Path, spellcheck: Spellcheck, save_path: Path, wait: int = 0 ) -> None: - """Preparation of the test benchmark using labeled data. + """Preparation of the test benchmark using labeled data. This step helps us prompt engineering GPT-3.5/GPT-4 to later augmentte our data. 
Args: @@ -60,14 +57,11 @@ def prepare_test_benchmark( { "original": original, "reference": reference, - "openai_prediction": spellcheck.correct(original) + "openai_prediction": spellcheck.correct(original), } ) time.sleep(wait) - save_data( - data={"data": output_data}, - save_path=save_path - ) + save_data(data={"data": output_data}, save_path=save_path) LOGGER.info("Test benchmark created and saved.") @@ -89,4 +83,4 @@ def save_data(data: Mapping, save_path: Path) -> None: if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/spellcheck/scripts/benchmark/generate_synthetic_data_for_additional_products.py b/spellcheck/scripts/benchmark/generate_synthetic_data_for_additional_products.py index 82b145d2..b5ca8f6a 100644 --- a/spellcheck/scripts/benchmark/generate_synthetic_data_for_additional_products.py +++ b/spellcheck/scripts/benchmark/generate_synthetic_data_for_additional_products.py @@ -1,4 +1,5 @@ """Append new ingredients lists corrections using LLM to the benchmark.""" + import pandas as pd from spellcheck.spellcheck import Spellcheck @@ -10,8 +11,14 @@ LOGGER = get_logger("INFO") REPO_DIR = get_repo_dir() -ADDITIONAL_PRODUCTS_PATH = REPO_DIR / "data/benchmark/additional_products/extracted_additional_products.parquet" -SYNTHETIC_DATA_PATH = REPO_DIR / "data/benchmark/additional_products/synthetically_corrected_products.parquet" +ADDITIONAL_PRODUCTS_PATH = ( + REPO_DIR + / "data/benchmark/additional_products/extracted_additional_products.parquet" +) +SYNTHETIC_DATA_PATH = ( + REPO_DIR + / "data/benchmark/additional_products/synthetically_corrected_products.parquet" +) MODEL_NAME = "gpt-3.5-turbo" @@ -23,7 +30,7 @@ def main(): model=OpenAIChatCompletion( prompt_template=Prompt.spellcheck_prompt_template, system_prompt=SystemPrompt.spellcheck_system_prompt, - model_name=MODEL_NAME + model_name=MODEL_NAME, ) ) @@ -34,6 +41,7 @@ def main(): # Save df.to_parquet(SYNTHETIC_DATA_PATH) - + + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/spellcheck/scripts/dags/benchmark_generation.py b/spellcheck/scripts/dags/benchmark_generation.py index 6dcff7ea..20ced0fa 100644 --- a/spellcheck/scripts/dags/benchmark_generation.py +++ b/spellcheck/scripts/dags/benchmark_generation.py @@ -15,7 +15,7 @@ class GenerateBenchmarkPipeline(metaflow.FlowSpec): """A pipeline to create a synthetically generated benchmark using LLM and deploy it to Argilla. - + NOTE: Extracting and processing data directly from the database using DuckDB is considered. Right now, the process is done out of this pipeline. The process involves avoiding duplicates with the training dataset and the existing benchmark. """ @@ -26,28 +26,28 @@ class GenerateBenchmarkPipeline(metaflow.FlowSpec): name="benchmark_path", default="data/benchmark/new_benchmark.parquet", type=str, - help="Path of the benchmark" + help="Path of the benchmark", ) model_name = metaflow.Parameter( name="model_name", default="gpt-3.5-turbo", type=str, - help="LLM name. Need to be the official model endpoint name." + help="LLM name. Need to be the official model endpoint name.", ) argilla_dataset_name = metaflow.Parameter( name="argilla_dataset_name", required=True, type=str, - help="Argilla dataset name during deployment." + help="Argilla dataset name during deployment.", ) argilla_workspace_name = metaflow.Parameter( name="argilla_workspace_name", default="spellcheck", type=str, - help="Argilla workspace name." 
+ help="Argilla workspace name.", ) @metaflow.step @@ -58,25 +58,27 @@ def start(self): # Check if all features are in the dataset for feature in self.features: if feature not in self.benchmark_data.column_names: - raise ValueError(f"The feature {feature} is not in the dataset features: {self.benchmark_data.column_names}") + raise ValueError( + f"The feature {feature} is not in the dataset features: {self.benchmark_data.column_names}" + ) self.next(self.generate_synthetic_corrections) @metaflow.step def generate_synthetic_corrections(self): - """Using a foundational LLM, usually closed-source such as Gemini or OpenAI, generate the correction of + """Using a foundational LLM, usually closed-source such as Gemini or OpenAI, generate the correction of lists of ingredients extracted from the Open Food Facts database.""" spellcheck = Spellcheck( model=OpenAIChatCompletion( prompt_template=Prompt.spellcheck_prompt_template, system_prompt=SystemPrompt.spellcheck_system_prompt, - model_name=self.model_name + model_name=self.model_name, ) ) def process_fn(element: Dict) -> Dict: """Fill the dataset with references, corresponding to corrections, generated with the LLM. Since the dataset can be composed of an existing benchmark complemented with new examples, we ensure to generate corrections - only for new data. + only for new data. Args: element (Dict): Dictionnary element during the dataset mapping. @@ -85,12 +87,14 @@ def process_fn(element: Dict) -> Dict: Dict: Processed dataset element. """ if not element["reference"]: - element["reference"] = spellcheck.correct(element["original"]) #TODO: OpenAI processing (tqdm) not showing + element["reference"] = spellcheck.correct( + element["original"] + ) # TODO: OpenAI processing (tqdm) not showing return element self.processed_benchmark = self.benchmark_data.map( - process_fn, - batched=False, # Really important to set it as false since the process function doesn't consider lists but singular texts. + process_fn, + batched=False, # Really important to set it as false since the process function doesn't consider lists but singular texts. ) self.next(self.push_to_argilla) @@ -102,14 +106,14 @@ def push_to_argilla(self): references=self.processed_benchmark["reference"], metadata=[ { - "lang": element["lang"], + "lang": element["lang"], "code": element["code"], - } + } for element in self.processed_benchmark - ] + ], ).deploy( dataset_name=self.argilla_dataset_name, - workspace_name=self.argilla_workspace_name + workspace_name=self.argilla_workspace_name, ) self.next(self.end) diff --git a/spellcheck/scripts/dags/data_processing.py b/spellcheck/scripts/dags/data_processing.py index f0107bf9..4de035d0 100644 --- a/spellcheck/scripts/dags/data_processing.py +++ b/spellcheck/scripts/dags/data_processing.py @@ -9,29 +9,29 @@ LOGGER = get_logger("INFO") + class DataProcessing(FlowSpec): - """Processing pipeline to modify the Spellcheck training dataset algorithmically. - """ + """Processing pipeline to modify the Spellcheck training dataset algorithmically.""" dataset_hf_repo = Parameter( name="dataset_hf_repo", default="openfoodfacts/spellcheck-dataset", type=str, - help="Dataset id stored in the Hugging Face OFF repository." + help="Dataset id stored in the Hugging Face OFF repository.", ) dataset_revision = Parameter( name="dataset_revision", default="v3", type=str, - help="Dataset revision indicating the version for processing. Default to v3." + help="Dataset revision indicating the version for processing. 
Default to v3.", ) dataset_version = Parameter( name="dataset_version", type=str, required=True, - help="New processed dataset version." + help="New processed dataset version.", ) dataset_split = Parameter( @@ -47,15 +47,17 @@ class DataProcessing(FlowSpec): type=float, required=False, default=0.1, - help="Dataset test split size used during push_to_hub." + help="Dataset test split size used during push_to_hub.", ) @step def start(self): """Load dataset""" if self.dataset_split not in ["train", "test", "train+test"]: - raise ValueError("Invalid value for dataset_split. Should be 'train', 'test', or 'train+test'.") - + raise ValueError( + "Invalid value for dataset_split. Should be 'train', 'test', or 'train+test'." + ) + self.dataset = load_dataset( self.dataset_hf_repo, revision=self.dataset_revision, @@ -63,7 +65,7 @@ def start(self): ) LOGGER.info(f"Dataset loaded:\n{self.dataset}") self.next(self.process) - + @step def process(self): """Process dataset.""" @@ -78,12 +80,10 @@ def process_fn(sample: Mapping) -> Mapping: (Mapping): Processed batch. """ processed_labels = DataProcessor.align_oe( - references=sample["text"], - texts=sample["label"] + references=sample["text"], texts=sample["label"] ) processed_labels = DataProcessor.align_whitespace_percentage( - references=sample["text"], - texts=processed_labels + references=sample["text"], texts=processed_labels ) sample["label"] = processed_labels return sample @@ -101,7 +101,8 @@ def end(self): ).push_to_hub( repo_id=self.dataset_hf_repo, commit_message=self.dataset_version, - commit_description="Metaflow run id:" + current.run_id, # Store Metaflow run id for traceability + commit_description="Metaflow run id:" + + current.run_id, # Store Metaflow run id for traceability ) LOGGER.info("Data processing finished succesfully.") diff --git a/spellcheck/scripts/dags/extract_from_argilla.py b/spellcheck/scripts/dags/extract_from_argilla.py index 8e24829a..2a74da2e 100644 --- a/spellcheck/scripts/dags/extract_from_argilla.py +++ b/spellcheck/scripts/dags/extract_from_argilla.py @@ -11,28 +11,28 @@ class SpellcheckExtractionFromArgillaPipeline(metaflow.FlowSpec): name="status", help="Which status to extract from Argilla. Can be 'submitted', 'pending', 'draft', 'discarded'", default="submitted", - multiple=True # In the CLI: ... --status submitted --status pending + multiple=True, # In the CLI: ... --status submitted --status pending ) dataset_hf_repo = metaflow.Parameter( name="dataset_hf_repo", required=True, type=str, - help="Hugging Face dataset repo id." + help="Hugging Face dataset repo id.", ) dataset_revision = metaflow.Parameter( name="dataset_revision", type=str, required=True, - help="Uploaded dataset branch. Each branch is a revision containing the different version of the dataset." + help="Uploaded dataset branch. Each branch is a revision containing the different version of the dataset.", ) dataset_version = metaflow.Parameter( name="dataset_version", type=str, required=True, - help="New version of the dataset as a commit in the main and revision branch." + help="New version of the dataset as a commit in the main and revision branch.", ) dataset_test_size = metaflow.Parameter( @@ -40,14 +40,14 @@ class SpellcheckExtractionFromArgillaPipeline(metaflow.FlowSpec): type=float, required=False, default=0, - help="Dataset test split size used during push_to_hub. If 0, the entire dataset is labeled under 'train', meaning there is no split." + help="Dataset test split size used during push_to_hub. 
If 0, the entire dataset is labeled under 'train', meaning there is no split.", ) argilla_dataset_name = metaflow.Parameter( name="argilla_dataset_name", type=str, required=True, - help="Dataset to extract from Argilla." + help="Dataset to extract from Argilla.", ) argilla_workspace_name = metaflow.Parameter( @@ -55,14 +55,14 @@ class SpellcheckExtractionFromArgillaPipeline(metaflow.FlowSpec): type=str, required=False, default="spellcheck", - help="Argilla workspace name. Default to 'spellcheck'." + help="Argilla workspace name. Default to 'spellcheck'.", ) argilla_dataset_local_path = metaflow.Parameter( name="local_path", type=str, required=True, - help="Local path to store the argilla dataset during the pipeline process." + help="Local path to store the argilla dataset during the pipeline process.", ) deploy_to_hf = metaflow.Parameter( @@ -70,7 +70,7 @@ class SpellcheckExtractionFromArgillaPipeline(metaflow.FlowSpec): type=bool, required=False, default=False, - help="Whether to push to dataset to HuggingFace." + help="Whether to push to dataset to HuggingFace.", ) additional_commit_info = metaflow.Parameter( @@ -78,7 +78,7 @@ class SpellcheckExtractionFromArgillaPipeline(metaflow.FlowSpec): type=str, required=False, default="", - help="Whether to add a commit description to the commit." + help="Whether to add a commit description to the commit.", ) @metaflow.step @@ -88,8 +88,7 @@ def start(self): @metaflow.step def extract_from_argilla(self): - """Argilla extraction step. Takes the status as input the user wants to extract. - """ + """Argilla extraction step. Takes the status as input the user wants to extract.""" print("Start extraction from Argilla.") argilla_dataset = SpellcheckExtraction( dataset_name=self.argilla_dataset_name, @@ -105,21 +104,24 @@ def extract_from_argilla(self): @metaflow.step def push_to_hf(self): """Conditional step. - + Push the extracted dataset to a HuggingFace dataset repo. This step takes the version of the dataset as a commit, the revision as a branch where to push the commit. By default, any modification is pushed to the main branch along the revision branch. """ if self.deploy_to_hf: - print(f"Start deploying to HuggingFace. \ + print( + f"Start deploying to HuggingFace. \ Repo_id: {self.dataset_hf_repo} - \ Revision: {self.dataset_revision} - \ Data version: {self.dataset_version}" ) dataset = Dataset.from_parquet(self.argilla_dataset_local_path) commit_description = ( - "metaflow run id: " + metaflow.current.run_id - + "\n\n" + self.additional_commit_info + "metaflow run id: " + + metaflow.current.run_id + + "\n\n" + + self.additional_commit_info ) # Check if test_size applied @@ -132,7 +134,7 @@ def push_to_hf(self): revision="main", commit_message=self.dataset_version, commit_description=commit_description, - ) + ) # Push to revision branch dataset.push_to_hub( repo_id=self.dataset_hf_repo, @@ -142,7 +144,7 @@ def push_to_hf(self): ) else: print(f"No deployment to HF. Condition is {self.deploy_to_hf}") - + self.next(self.end) @metaflow.step diff --git a/spellcheck/scripts/dags/training/pretraining_finetuning.py b/spellcheck/scripts/dags/training/pretraining_finetuning.py index d7090145..546ae848 100644 --- a/spellcheck/scripts/dags/training/pretraining_finetuning.py +++ b/spellcheck/scripts/dags/training/pretraining_finetuning.py @@ -14,7 +14,7 @@ class TrainingPipeline(metaflow.FlowSpec): - """Spellcheck training pipeline. + """Spellcheck training pipeline. Model can either be trained locally, or on the cloud. 
""" @@ -63,23 +63,25 @@ class TrainingPipeline(metaflow.FlowSpec): "evaluation_data_version", required=True, type=str, - help="Version of the evaluation dataset used during training." + help="Version of the evaluation dataset used during training.", ) @metaflow.step def start(self): - """Load all parameters from config file used during training. - """ + """Load all parameters from config file used during training.""" import comet_ml + # Create experiment in CometML and log information before starting the training job experiment = comet_ml.Experiment() experiment.add_tags(list(self.experiment_tags)) - experiment.log_parameters({ - "metaflow_run_id": metaflow.current.run_id, - "training_data_version": self.training_data_version, - "evaluation_data_version": self.evaluation_data_version, - "pretraining_data_version": self.pretraining_data_version, - }) + experiment.log_parameters( + { + "metaflow_run_id": metaflow.current.run_id, + "training_data_version": self.training_data_version, + "evaluation_data_version": self.evaluation_data_version, + "pretraining_data_version": self.pretraining_data_version, + } + ) self.experiment_key = experiment.get_key() experiment.end() self.next(self.pretrain) @@ -87,38 +89,48 @@ def start(self): @metaflow.step def pretrain(self): """Pre-training step. - + Use Sagemaker Training Job to package and run the training script in production. """ from sagemaker.huggingface import HuggingFace from omegaconf import OmegaConf - + LOGGER.info(f"Configuration file used: {self.pretraining_conf_path}") self.pretraining_conf = OmegaConf.load(self.pretraining_conf_path) - hyperparameters = OmegaConf.to_container(self.pretraining_conf.hyperparameters, resolve=True) # Transform DictConfig to Dict + hyperparameters = OmegaConf.to_container( + self.pretraining_conf.hyperparameters, resolve=True + ) # Transform DictConfig to Dict LOGGER.info(f"Configs: {self.pretraining_conf}") # Prepare Sagemaker estimator estimator = HuggingFace( - role= os.getenv("SAGEMAKER_ROLE"), # Iam role used in training job to access AWS ressources, e.g. S3 - hyperparameters= hyperparameters, # hyperparameters used for the training job - environment={ # environment variables used during training - "COMET_PROJECT_NAME": os.getenv("COMET_PROJECT_NAME"), # comet project name - "COMET_API_KEY": os.getenv("COMET_API_KEY"), - "COMET_EXPERIMENT_KEY": self.experiment_key, - "HF_TOKEN": os.getenv("HF_TOKEN"), # required by some models, such as llama-3 or Mistral - "S3_MODEL_URI": self.pretraining_conf.estimator.output_path, # the uri where the model artifact is stored is actually not in the SM_TRAINING_JOB environment variables. Let's add it. + role=os.getenv( + "SAGEMAKER_ROLE" + ), # Iam role used in training job to access AWS ressources, e.g. S3 + hyperparameters=hyperparameters, # hyperparameters used for the training job + environment={ # environment variables used during training + "COMET_PROJECT_NAME": os.getenv( + "COMET_PROJECT_NAME" + ), # comet project name + "COMET_API_KEY": os.getenv("COMET_API_KEY"), + "COMET_EXPERIMENT_KEY": self.experiment_key, + "HF_TOKEN": os.getenv( + "HF_TOKEN" + ), # required by some models, such as llama-3 or Mistral + "S3_MODEL_URI": self.pretraining_conf.estimator.output_path, # the uri where the model artifact is stored is actually not in the SM_TRAINING_JOB environment variables. Let's add it. }, - **self.pretraining_conf.estimator, + **self.pretraining_conf.estimator, ) # Run training job - estimator.fit(wait=True) # Wait for the pipeline. 
No need for inputs since data doesn't come from S3. - + estimator.fit( + wait=True + ) # Wait for the pipeline. No need for inputs since data doesn't come from S3. + # Log Sagemaker training information into metaflow after training job self.sagemaker_pretraining_job_id = estimator.latest_training_job.job_name self.pretrained_model_artifact_uri = ( - self.pretraining_conf.estimator.output_path + self.pretraining_conf.estimator.output_path + self.sagemaker_training_job_id + "output/model/" ) @@ -127,7 +139,7 @@ def pretrain(self): @metaflow.step def train(self): """Training step. - + Use Sagemaker Training Job to package and run the training script in production. """ from sagemaker.huggingface import HuggingFace @@ -135,45 +147,61 @@ def train(self): LOGGER.info(f"Configuration file used: {self.training_conf_path}") self.training_conf = OmegaConf.load(self.training_conf_path) - hyperparameters = OmegaConf.to_container(self.training_conf.hyperparameters, resolve=True) # Transform DictConfig to Dict + hyperparameters = OmegaConf.to_container( + self.training_conf.hyperparameters, resolve=True + ) # Transform DictConfig to Dict LOGGER.info(f"Configs: {self.training_conf}") # Prepare Sagemaker estimator estimator = HuggingFace( - role= os.getenv("SAGEMAKER_ROLE"), # Iam role used in training job to access AWS ressources, e.g. S3 - hyperparameters= hyperparameters, # hyperparameters used for the training job - environment={ # environment variables used during training - "COMET_PROJECT_NAME": os.getenv("COMET_PROJECT_NAME"), # comet project name - "COMET_API_KEY": os.getenv("COMET_API_KEY"), - "COMET_EXPERIMENT_KEY": self.experiment_key, - "HF_TOKEN": os.getenv("HF_TOKEN"), # required by some models, such as llama-3 or Mistral - "S3_MODEL_URI": self.training_conf.estimator.output_path, # the uri where the model artifact is stored is actually not in the SM_TRAINING_JOB environment variables. Let's add it. - "S3_EVALUATION_URI": self.training_conf.additional_conf.s3_evaluation_uri, # s3 uri where evaluation prediction are stored + role=os.getenv( + "SAGEMAKER_ROLE" + ), # Iam role used in training job to access AWS ressources, e.g. S3 + hyperparameters=hyperparameters, # hyperparameters used for the training job + environment={ # environment variables used during training + "COMET_PROJECT_NAME": os.getenv( + "COMET_PROJECT_NAME" + ), # comet project name + "COMET_API_KEY": os.getenv("COMET_API_KEY"), + "COMET_EXPERIMENT_KEY": self.experiment_key, + "HF_TOKEN": os.getenv( + "HF_TOKEN" + ), # required by some models, such as llama-3 or Mistral + "S3_MODEL_URI": self.training_conf.estimator.output_path, # the uri where the model artifact is stored is actually not in the SM_TRAINING_JOB environment variables. Let's add it. 
+ "S3_EVALUATION_URI": self.training_conf.additional_conf.s3_evaluation_uri, # s3 uri where evaluation prediction are stored }, - **self.training_conf.estimator, + **self.training_conf.estimator, ) # Run training job - estimator.fit(wait=True, inputs={"model": self.pretrained_model_artifact_uri}) # Add previous model - + estimator.fit( + wait=True, inputs={"model": self.pretrained_model_artifact_uri} + ) # Add previous model + # Log Sagemaker training information into metaflow after training job self.sagemaker_training_job_id = estimator.latest_training_job.job_name - self.model_artifact_uri = self.training_conf.estimator.output_path + self.sagemaker_training_job_id + self.model_artifact_uri = ( + self.training_conf.estimator.output_path + self.sagemaker_training_job_id + ) self.evaluation_uri = os.path.join( self.training_conf.additional_conf.s3_evaluation_uri, - "evaluation-" + self.sagemaker_training_job_id + "evaluation-" + self.sagemaker_training_job_id, ) self.next(self.human_evaluation) @metaflow.step def human_evaluation(self): """Conditional step: - Push model predictions against the Benchmark to Argilla. + Push model predictions against the Benchmark to Argilla. Evaluation dataset was generated during the training step on the Sagemaker instance. """ if self.do_human_eval: # Get the latest experiment - self.argilla_dataset_name = self.training_conf.estimator.base_job_name + "-exp-key-" + self.experiment_key + self.argilla_dataset_name = ( + self.training_conf.estimator.base_job_name + + "-exp-key-" + + self.experiment_key + ) BenchmarkEvaluationArgilla.from_s3(self.evaluation_uri).deploy( dataset_name=self.argilla_dataset_name ) diff --git a/spellcheck/scripts/dags/training/training.py b/spellcheck/scripts/dags/training/training.py index 495e1210..7f5a1212 100644 --- a/spellcheck/scripts/dags/training/training.py +++ b/spellcheck/scripts/dags/training/training.py @@ -14,7 +14,7 @@ class TrainingPipeline(metaflow.FlowSpec): - """Spellcheck training pipeline. + """Spellcheck training pipeline. Model can either be trained locally, or on the cloud. """ @@ -50,22 +50,24 @@ class TrainingPipeline(metaflow.FlowSpec): "evaluation_data_version", required=True, type=str, - help="Version of the evaluation dataset used during training." + help="Version of the evaluation dataset used during training.", ) @metaflow.step def start(self): - """Load all parameters from config file used during training. - """ + """Load all parameters from config file used during training.""" import comet_ml + # Create experiment in CometML and log information before starting the training job experiment = comet_ml.Experiment() experiment.add_tags(list(self.experiment_tags)) - experiment.log_parameters({ - "metaflow_run_id": metaflow.current.run_id, - "training_data_version": self.training_data_version, - "evaluation_data_version": self.evaluation_data_version, - }) + experiment.log_parameters( + { + "metaflow_run_id": metaflow.current.run_id, + "training_data_version": self.training_data_version, + "evaluation_data_version": self.evaluation_data_version, + } + ) self.experiment_key = experiment.get_key() experiment.end() self.next(self.train) @@ -73,7 +75,7 @@ def start(self): @metaflow.step def train(self): """Training step. - + Use Sagemaker Training Job to package and run the training script in production. 
""" from sagemaker.huggingface import HuggingFace @@ -82,60 +84,80 @@ def train(self): LOGGER.info(f"Configuration file used: {self.training_conf_path}") self.training_conf = OmegaConf.load(self.training_conf_path) - hyperparameters = OmegaConf.to_container(self.training_conf.hyperparameters, resolve=True) # Transform DictConfig to Dict + hyperparameters = OmegaConf.to_container( + self.training_conf.hyperparameters, resolve=True + ) # Transform DictConfig to Dict LOGGER.info(f"Configs: {self.training_conf}") # CometML - experiment = comet_ml.ExistingExperiment(previous_experiment=self.experiment_key) + experiment = comet_ml.ExistingExperiment( + previous_experiment=self.experiment_key + ) experiment.log_parameters(hyperparameters) - experiment.log_code(file_name=os.path.join( - self.training_conf.estimator.source_dir, - self.training_conf.estimator.entry_point - )) + experiment.log_code( + file_name=os.path.join( + self.training_conf.estimator.source_dir, + self.training_conf.estimator.entry_point, + ) + ) # Save config as code experiment.log_code(file_name=self.training_conf_path) # Prepare Sagemaker estimator estimator = HuggingFace( - role= os.getenv("SAGEMAKER_ROLE"), # Iam role used in training job to access AWS ressources, e.g. S3 - hyperparameters= hyperparameters, # hyperparameters used for the training job - environment={ # environment variables used during training - "COMET_PROJECT_NAME": os.getenv("COMET_PROJECT_NAME"), # comet project name - "COMET_API_KEY": os.getenv("COMET_API_KEY"), - "COMET_EXPERIMENT_KEY": self.experiment_key, - "HF_TOKEN": os.getenv("HF_TOKEN"), # required by some models, such as llama-3 or Mistral - "S3_MODEL_URI": self.training_conf.estimator.output_path, # the uri where the model artifact is stored is actually not in the SM_TRAINING_JOB environment variables. Let's add it. - # "S3_EVALUATION_URI": self.training_conf.additional_conf.s3_evaluation_uri, # s3 uri where evaluation prediction are stored + role=os.getenv( + "SAGEMAKER_ROLE" + ), # Iam role used in training job to access AWS ressources, e.g. S3 + hyperparameters=hyperparameters, # hyperparameters used for the training job + environment={ # environment variables used during training + "COMET_PROJECT_NAME": os.getenv( + "COMET_PROJECT_NAME" + ), # comet project name + "COMET_API_KEY": os.getenv("COMET_API_KEY"), + "COMET_EXPERIMENT_KEY": self.experiment_key, + "HF_TOKEN": os.getenv( + "HF_TOKEN" + ), # required by some models, such as llama-3 or Mistral + "S3_MODEL_URI": self.training_conf.estimator.output_path, # the uri where the model artifact is stored is actually not in the SM_TRAINING_JOB environment variables. Let's add it. 
+ # "S3_EVALUATION_URI": self.training_conf.additional_conf.s3_evaluation_uri, # s3 uri where evaluation prediction are stored }, - **self.training_conf.estimator, + **self.training_conf.estimator, ) # Run training job - estimator.fit(wait=True) # Add previous model - + estimator.fit(wait=True) # Add previous model + # Log Sagemaker training information into metaflow after training job self.sagemaker_training_job_id = estimator.latest_training_job.job_name - self.model_artifact_uri = self.training_conf.estimator.output_path + self.sagemaker_training_job_id + self.model_artifact_uri = ( + self.training_conf.estimator.output_path + self.sagemaker_training_job_id + ) # self.evaluation_uri = os.path.join( # self.training_conf.additional_conf.s3_evaluation_uri, # "evaluation-" + self.sagemaker_training_job_id # ) # Log model artifact in cometml - experiment.log_remote_model(model_name="model", model_path=self.model_artifact_uri, sync_mode=False) + experiment.log_remote_model( + model_name="model", model_path=self.model_artifact_uri, sync_mode=False + ) experiment.log_parameter("training_job_name", self.sagemaker_training_job_id) - + self.next(self.human_evaluation) @metaflow.step def human_evaluation(self): """Conditional step: - Push model predictions against the Benchmark to Argilla. + Push model predictions against the Benchmark to Argilla. Evaluation dataset was generated during the training step on the Sagemaker instance. """ if self.do_human_eval: # Get the latest experiment - self.argilla_dataset_name = self.training_conf.estimator.base_job_name + "-exp-key-" + self.experiment_key + self.argilla_dataset_name = ( + self.training_conf.estimator.base_job_name + + "-exp-key-" + + self.experiment_key + ) BenchmarkEvaluationArgilla.from_s3(self.evaluation_uri).deploy( dataset_name=self.argilla_dataset_name ) diff --git a/spellcheck/scripts/dataset/0_extract_data.py b/spellcheck/scripts/dataset/0_extract_data.py index 5baeb1a9..db26c8ae 100644 --- a/spellcheck/scripts/dataset/0_extract_data.py +++ b/spellcheck/scripts/dataset/0_extract_data.py @@ -19,7 +19,7 @@ "ingredients_text", "unknown_ingredients_n", "known_ingredients_n", - "ingredients_n" + "ingredients_n", ] DTYPES_MAPPING = { "ingredients_n": pl.Int16, @@ -36,7 +36,7 @@ def main(): if not DATA_PATH.is_file(): raise ValueError(f"Data path is not valid: {str(DATA_PATH)}") - + # Load benchmark and dataset to remove duplicates benchmark_df = pl.read_parquet(BENCHMARK_PATH) previous_dataset = load_dataset(HF_DATASET_ID, split="train+test") @@ -55,8 +55,9 @@ def main(): LOGGER.info(f"The extracted dataset contains {len(df)} rows.") df.write_parquet(OUTPUT_DATA_PATH) + @timer -def extract_data( +def extract_data( df: pl.LazyFrame, percentage_unknown_range: Tuple[float, float], keep_features: List[str], @@ -68,7 +69,7 @@ def extract_data( ) -> pl.DataFrame: """Extracts products from the JSONL database based on their percentage unknown range. - Notes: + Notes: Take around 8 minutes to run. 
Args: @@ -89,11 +90,26 @@ def extract_data( output_df = ( df.select(pl.col(*keep_features)) .drop_nulls() - .with_columns((pl.col("unknown_ingredients_n") / pl.col("ingredients_n")).alias("fraction")) - .filter((pl.col("fraction") >= percentage_min) & (pl.col("fraction") <= percentage_max)) - .unique(subset=["ingredients_text",]) # Remove duplicates within the dataset - .filter(~pl.col("ingredients_text").is_in(benchmark_df["original"])) # Remove duplicates with Benchmark - .filter(~pl.col("ingredients_text").is_in(previous_dataset["text"])) # Remove duplicates with previous dataset + .with_columns( + (pl.col("unknown_ingredients_n") / pl.col("ingredients_n")).alias( + "fraction" + ) + ) + .filter( + (pl.col("fraction") >= percentage_min) + & (pl.col("fraction") <= percentage_max) + ) + .unique( + subset=[ + "ingredients_text", + ] + ) # Remove duplicates within the dataset + .filter( + ~pl.col("ingredients_text").is_in(benchmark_df["original"]) + ) # Remove duplicates with Benchmark + .filter( + ~pl.col("ingredients_text").is_in(previous_dataset["text"]) + ) # Remove duplicates with previous dataset .collect(streaming=True) .sample(n=dataset_size, shuffle=True, seed=seed) .cast(dtype_output_mapping) diff --git a/spellcheck/scripts/dataset/1_generate_synthetic_data.py b/spellcheck/scripts/dataset/1_generate_synthetic_data.py index f0954ffd..e5a45185 100644 --- a/spellcheck/scripts/dataset/1_generate_synthetic_data.py +++ b/spellcheck/scripts/dataset/1_generate_synthetic_data.py @@ -23,7 +23,7 @@ def main(): - + df = pd.read_parquet(DATA_PATH) existing_codes = ( pd.read_json(SYNTHETIC_DATA_PATH, lines=True)["code"].to_list() @@ -34,7 +34,7 @@ def main(): model=OpenAIChatCompletion( prompt_template=Prompt.spellcheck_prompt_template, system_prompt=SystemPrompt.spellcheck_system_prompt, - model_name=MODEL_NAME + model_name=MODEL_NAME, ) ) generate_synthetic_data( @@ -43,7 +43,7 @@ def main(): existing_codes=existing_codes, spellcheck=spellcheck, original_text_feature="ingredients_text", - synthetic_feature = 'corrected_text' + synthetic_feature="corrected_text", ) @@ -53,14 +53,14 @@ def generate_synthetic_data( existing_codes: List, spellcheck: Spellcheck, original_text_feature: str, - synthetic_feature: str + synthetic_feature: str, ) -> None: """Generate synthetic data for text-based features using a spellcheck. Notes: - This function appends synthetic data to an existing file specified by output_data_path. - Each row in the DataFrame is processed, and if the code is not in the list of existing codes, the text in the original_text_feature column is corrected using the spellcheck and appended to the output file along with other data. - + Parameters: - df (pd.DataFrame): Input DataFrame containing the original data. - output_data_path (Path): Path to save the synthetic data. @@ -77,9 +77,11 @@ def generate_synthetic_data( LOGGER.info("Product was already generated. 
Pass.") else: row[synthetic_feature] = spellcheck.correct(row[original_text_feature]) - json.dump(row.to_dict(), file, ensure_ascii=False) # Ensure ascii for accents + json.dump( + row.to_dict(), file, ensure_ascii=False + ) # Ensure ascii for accents file.write("\n") - file.flush() # Immediatly write the line into the file + file.flush() # Immediatly write the line into the file LOGGER.info("Synthetic generation finished.") diff --git a/spellcheck/scripts/dataset/2_convert_to_dataset.py b/spellcheck/scripts/dataset/2_convert_to_dataset.py index b2ae1cdf..30ab8ce0 100644 --- a/spellcheck/scripts/dataset/2_convert_to_dataset.py +++ b/spellcheck/scripts/dataset/2_convert_to_dataset.py @@ -29,4 +29,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/spellcheck/scripts/evaluation/evaluate.py b/spellcheck/scripts/evaluation/evaluate.py index 5dc97a7e..2b5b387f 100644 --- a/spellcheck/scripts/evaluation/evaluate.py +++ b/spellcheck/scripts/evaluation/evaluate.py @@ -10,15 +10,22 @@ from spellcheck.utils import get_repo_dir, get_logger from spellcheck.spellcheck import Spellcheck from spellcheck.model import ( - AnthropicChatCompletion, - OpenAIChatCompletion, - RulesBasedModel, - GeminiModel, + AnthropicChatCompletion, + OpenAIChatCompletion, + RulesBasedModel, + GeminiModel, LLMInferenceEndpoint, ) from spellcheck.prompt import SystemPrompt, Prompt -from spellcheck.argilla.deployment import BenchmarkEvaluationArgilla, IngredientsCompleteEvaluationArgilla -from spellcheck.evaluation.evaluation import Evaluate, import_benchmark, import_ingredients_complete +from spellcheck.argilla.deployment import ( + BenchmarkEvaluationArgilla, + IngredientsCompleteEvaluationArgilla, +) +from spellcheck.evaluation.evaluation import ( + Evaluate, + import_benchmark, + import_ingredients_complete, +) REPO_DIR = get_repo_dir() @@ -34,25 +41,41 @@ INGREDIENTS_COMPLETE_VERSION = "v1" # Predictions JSONL paths to study the results -PREDICTIONS_EVALUATION_PATH = REPO_DIR / "data/evaluation/" / ( - MODEL_NAME - + "-benchmark-" + BENCHMARK_VERSION - + "-prompt-" + PROMPT_VERSION - + ".jsonl" +PREDICTIONS_EVALUATION_PATH = ( + REPO_DIR + / "data/evaluation/" + / ( + MODEL_NAME + + "-benchmark-" + + BENCHMARK_VERSION + + "-prompt-" + + PROMPT_VERSION + + ".jsonl" + ) ) -PREDICTION_INGREDIENTS_COMPLETE_PATH = REPO_DIR / "data/evaluation" / ( - MODEL_NAME - + "-ingredients-complete-data-" + INGREDIENTS_COMPLETE_VERSION - + "-prompt-" + PROMPT_VERSION - + ".jsonl" +PREDICTION_INGREDIENTS_COMPLETE_PATH = ( + REPO_DIR + / "data/evaluation" + / ( + MODEL_NAME + + "-ingredients-complete-data-" + + INGREDIENTS_COMPLETE_VERSION + + "-prompt-" + + PROMPT_VERSION + + ".jsonl" + ) ) -START = 0 # To restart the run +START = 0 # To restart the run WAIT = 0 # Replace for gpt3.5 => "." 
not accepted by Argilla -ARGILLA_BENCHMARK_DATASET_NAME = f"Evaluation-{MODEL_NAME}-benchmark-{BENCHMARK_VERSION}-prompt-{PROMPT_VERSION}".replace(".", "") -ARGILLA_INGREDIENTS_COMPLETE_DATASET_NAME = f"Evaluation-{MODEL_NAME}-ingredients-complete-{INGREDIENTS_COMPLETE_VERSION}-prompt-{PROMPT_VERSION}".replace(".", "") +ARGILLA_BENCHMARK_DATASET_NAME = f"Evaluation-{MODEL_NAME}-benchmark-{BENCHMARK_VERSION}-prompt-{PROMPT_VERSION}".replace( + ".", "" +) +ARGILLA_INGREDIENTS_COMPLETE_DATASET_NAME = f"Evaluation-{MODEL_NAME}-ingredients-complete-{INGREDIENTS_COMPLETE_VERSION}-prompt-{PROMPT_VERSION}".replace( + ".", "" +) LOGGER = get_logger("INFO") @@ -60,9 +83,9 @@ def main(): - spellcheck=Spellcheck( + spellcheck = Spellcheck( model=OpenAIChatCompletion( - prompt_template=Prompt.spellcheck_prompt_template, #If Claude, use custom prompt template + prompt_template=Prompt.spellcheck_prompt_template, # If Claude, use custom prompt template system_prompt=SystemPrompt.spellcheck_system_prompt, model_name=MODEL_NAME, ) @@ -70,8 +93,7 @@ def main(): ####################### Evaluate on benchmark originals, references, metadata = import_benchmark( - path=BENCHMARK_PATH, - start_from=START + path=BENCHMARK_PATH, start_from=START ) evaluation = Evaluate( model_name=MODEL_NAME, @@ -85,15 +107,13 @@ def main(): references=references, metadata=metadata, spellcheck=spellcheck, - wait=WAIT + wait=WAIT, ) evaluation.compute_metrics() # Human evaluation - BenchmarkEvaluationArgilla.from_jsonl( - path=PREDICTIONS_EVALUATION_PATH - ).deploy( - dataset_name=ARGILLA_BENCHMARK_DATASET_NAME) - + BenchmarkEvaluationArgilla.from_jsonl(path=PREDICTIONS_EVALUATION_PATH).deploy( + dataset_name=ARGILLA_BENCHMARK_DATASET_NAME + ) # ####################### Evaluate on Ingredient complete dataset # originals, references, metadata = import_ingredients_complete(path=INGREDIENTS_COMPLETE_DATA_PATH) diff --git a/spellcheck/scripts/old_to_new/0_convert_old_data.py b/spellcheck/scripts/old_to_new/0_convert_old_data.py index e5b52851..77e8fe55 100644 --- a/spellcheck/scripts/old_to_new/0_convert_old_data.py +++ b/spellcheck/scripts/old_to_new/0_convert_old_data.py @@ -1,5 +1,6 @@ """Extract the old data prepared by Lucain W and save it as a json file for future work. """ + import os from pathlib import Path from typing import Iterator, Mapping, List @@ -46,14 +47,14 @@ def postprocess_texts( original_texts: Iterator[str], reference_texts: Iterator[str], unaccepted_string: str = "NOT_VALID", - lang: str = "fr" + lang: str = "fr", ) -> List[Mapping]: - """Map original and reference texts. Remove + """Map original and reference texts. Remove Args: original_texts (Iterator[str]): Before Spellchecking reference_texts (Iterator[str]): After Spellchecking - unaccepted_string (str, optional): Some `After spellcheck` data were considered as not valid. + unaccepted_string (str, optional): Some `After spellcheck` data were considered as not valid. We remove them. Defaults to "NOT_VALID". lang (str, optional): Langue. Defaults to "fr". 
@@ -84,7 +85,7 @@ def main(): convert_old_data( original_path=ORIGINAL_DATA_PATH, reference_path=REFERENCE_DATA_PATH, - save_path=OUTPUT_DATA_PATH + save_path=OUTPUT_DATA_PATH, ) diff --git a/spellcheck/scripts/training/flan-t5/flan-t5.py b/spellcheck/scripts/training/flan-t5/flan-t5.py index 7a4584d8..b75ef826 100644 --- a/spellcheck/scripts/training/flan-t5/flan-t5.py +++ b/spellcheck/scripts/training/flan-t5/flan-t5.py @@ -1,4 +1,5 @@ """Flan-T5 training script.""" + import os import sys import json @@ -16,7 +17,7 @@ from datasets import load_dataset, disable_caching import numpy as np -from spellcheck.evaluation.evaluator import SpellcheckEvaluator +from spellcheck.evaluation.evaluator import SpellcheckEvaluator # For testing load_dotenv() @@ -25,19 +26,21 @@ LOGGER = logging.getLogger(__name__) logging.basicConfig( - level=logging.getLevelName("INFO"), - handlers=[logging.StreamHandler(sys.stdout)], # Get training logging during training job on Sagemaker - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + level=logging.getLevelName("INFO"), + handlers=[ + logging.StreamHandler(sys.stdout) + ], # Get training logging during training job on Sagemaker + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) # Retrieve Sagemaker job name to get the model artifact from S3 SM_TRAINING_ENV = json.loads(os.getenv("SM_TRAINING_ENV")) # Need to be deserialized SM_JOB_NAME = SM_TRAINING_ENV["job_name"] -# Where the model artifact is stored +# Where the model artifact is stored S3_MODEL_URI = os.getenv("S3_MODEL_URI") -# Tags. JSON Serialized as a string because List is not serializable +# Tags. JSON Serialized as a string because List is not serializable EXPERIMENT_TAGS = os.getenv("EXPERIMENT_TAGS").split(",") @@ -45,41 +48,107 @@ def parse_args(): parser = argparse.ArgumentParser() # Sagemaker environment - parser.add_argument("--training_data", type=str, default=os.getenv("SM_CHANNEL_TRAINING_DATA")) # "SM_CHANNEL_{name_data}" - parser.add_argument("--evaluation_data", type=str, default=os.getenv("SM_CHANNEL_EVALUATION_DATA")) + parser.add_argument( + "--training_data", type=str, default=os.getenv("SM_CHANNEL_TRAINING_DATA") + ) # "SM_CHANNEL_{name_data}" + parser.add_argument( + "--evaluation_data", type=str, default=os.getenv("SM_CHANNEL_EVALUATION_DATA") + ) parser.add_argument("--output_dir", type=str, default=os.getenv("SM_MODEL_DIR")) - #Training - parser.add_argument("--pretrained_model_name", type=str, help="Pretrained model id to fine-tune from the Hugging Face Hub.") - parser.add_argument("--num_train_epochs", type=float, default=1, help="Number of epochs.") - parser.add_argument("--per_device_train_batch_size", type=int, default=8, help="Training batch size.") - parser.add_argument("--per_device_eval_batch_size", type=int, default=8, help="Eval batch size.") + # Training + parser.add_argument( + "--pretrained_model_name", + type=str, + help="Pretrained model id to fine-tune from the Hugging Face Hub.", + ) + parser.add_argument( + "--num_train_epochs", type=float, default=1, help="Number of epochs." + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Training batch size.", + ) + parser.add_argument( + "--per_device_eval_batch_size", type=int, default=8, help="Eval batch size." 
+ ) parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate.") parser.add_argument("--seed", type=int, default=42, help="Seed.") - parser.add_argument("--warmup_steps", type=int, default=0, help="Number of steps used for a linear warmup from 0 to `learning_rate`") + parser.add_argument( + "--warmup_steps", + type=int, + default=0, + help="Number of steps used for a linear warmup from 0 to `learning_rate`", + ) parser.add_argument("--warmup_ratio", type=float, default=0, help="Warm-up ratio.") - parser.add_argument("--weight_decay", type=float, default=0, help="Weight decay to prevent overfitting") - parser.add_argument("--gradient_checkpointing", type=strtobool, default=False, help="To reduce GPU memory footprint during training") - parser.add_argument("--fp16", type=strtobool, default=False, help="Whether to use bf16.") - parser.add_argument("--generation_max_tokens", type=int, default=512, help="Max tokens used for text generation in the Trainer module.") + parser.add_argument( + "--weight_decay", + type=float, + default=0, + help="Weight decay to prevent overfitting", + ) + parser.add_argument( + "--gradient_checkpointing", + type=strtobool, + default=False, + help="To reduce GPU memory footprint during training", + ) + parser.add_argument( + "--fp16", type=strtobool, default=False, help="Whether to use bf16." + ) + parser.add_argument( + "--generation_max_tokens", + type=int, + default=512, + help="Max tokens used for text generation in the Trainer module.", + ) parser.add_argument("--optim", type=str, default="adamw_torch", help="Optimizer.") - parser.add_argument("--lr_scheduler_type", type=str, default="linear", help="Learning scheduler type.") - parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Accumulate bacthes before back propagation.") - parser.add_argument("--instruction", type=str, default="none", help="Flan-T5 instruction.") + parser.add_argument( + "--lr_scheduler_type", + type=str, + default="linear", + help="Learning scheduler type.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Accumulate bacthes before back propagation.", + ) + parser.add_argument( + "--instruction", type=str, default="none", help="Flan-T5 instruction." + ) # Versions - parser.add_argument("--training_data_version", type=str, default="v0", help="Training dataset version.") - parser.add_argument("--evaluation_data_version", type=str, default="v0", help="Evaluation dataset version.") + parser.add_argument( + "--training_data_version", + type=str, + default="v0", + help="Training dataset version.", + ) + parser.add_argument( + "--evaluation_data_version", + type=str, + default="v0", + help="Evaluation dataset version.", + ) # Evaluation - parser.add_argument("--beta", type=float, default=1, help="Coefficient used in f1-beta score. beta < 1 favors Precision over Recall.") + parser.add_argument( + "--beta", + type=float, + default=1, + help="Coefficient used in f1-beta score. beta < 1 favors Precision over Recall.", + ) args = parser.parse_known_args() return args def copy_files(dir: str, *filenames: Iterable[str]) -> None: - """Copy additional files into the model.tar.gz artifact. + """Copy additional files into the model.tar.gz artifact. 
Args: path (str): SM_MODEL_DIR / code @@ -87,16 +156,17 @@ def copy_files(dir: str, *filenames: Iterable[str]) -> None: os.makedirs(dir, exist_ok=True) for filename in filenames: shutil.copyfile( - os.path.join(os.path.dirname(__file__), filename), # Source dir - os.path.join(dir, filename) # Output_dir + os.path.join(os.path.dirname(__file__), filename), # Source dir + os.path.join(dir, filename), # Output_dir ) class FlanT5Training: - """Flan-T5 training. - """ + """Flan-T5 training.""" - padding = "max_length" # Padding configuration. "max_length" means the moodel maxm length + padding = ( + "max_length" # Padding configuration. "max_length" means the moodel maxm length + ) def train(self, args): """Training. @@ -108,38 +178,38 @@ def train(self, args): tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name) model = AutoModelForSeq2SeqLM.from_pretrained( args.pretrained_model_name, - use_cache= False if args.gradient_checkpointing else True, + use_cache=False if args.gradient_checkpointing else True, ) # Load data - train_dataset = load_dataset(path=args.training_data) + train_dataset = load_dataset(path=args.training_data) evaluation_dataset = load_dataset(path=args.evaluation_data, split="train") LOGGER.info(f"Training dataset: {train_dataset}") LOGGER.info(f"Evaluation dataset: {evaluation_dataset}") - + # Add instruction for Flan-T5 self.instruction = "" if args.instruction == "none" else args.instruction # Prepare datasets for training preprocessed_train_dataset = train_dataset.map( - self.preprocess, - batched=True, - remove_columns=train_dataset["train"].column_names, + self.preprocess, + batched=True, + remove_columns=train_dataset["train"].column_names, fn_kwargs={ - "input_name": "text", + "input_name": "text", "target_name": "label", "tokenizer": tokenizer, - } + }, ) preprocessed_evaluation_dataset = evaluation_dataset.map( self.preprocess, batched=True, remove_columns=evaluation_dataset.column_names, fn_kwargs={ - "input_name": "original", + "input_name": "original", "target_name": "reference", "tokenizer": tokenizer, - } + }, ) # Ignore tokenizer pad_token in the loss compuation @@ -149,43 +219,43 @@ def train(self, args): tokenizer, model=model, label_pad_token_id=label_pad_token_id, - pad_to_multiple_of=8 + pad_to_multiple_of=8, ) # Training training_args = Seq2SeqTrainingArguments( - output_dir = args.output_dir, # Model checkpoints directory - per_device_train_batch_size = args.per_device_train_batch_size, - per_device_eval_batch_size = args.per_device_eval_batch_size, - predict_with_generate = True, # Required for Seq2Seq - generation_max_length = args.generation_max_tokens, # Default to 20 (depends on the task) - fp16 = args.fp16, # Overflows with fp16 - learning_rate = args.lr, # https://huggingface.co/docs/transformers/en/model_doc/t5#training:~:text=Additional%20training%20tips%3A - num_train_epochs = args.num_train_epochs, - warmup_steps = args.warmup_steps, - warmup_ratio = args.warmup_ratio, - weight_decay = args.weight_decay, - gradient_checkpointing = args.gradient_checkpointing, - optim = args.optim, # AdamW or AdaFactor - lr_scheduler_type = args.lr_scheduler_type, - gradient_accumulation_steps = args.gradient_accumulation_steps, - #Logging & evaluation strategies - logging_dir = f"{args.output_dir}/logs", - logging_strategy = "steps", - logging_steps = 100, - evaluation_strategy = "epoch", - save_strategy = "epoch", - save_total_limit = 1, - load_best_model_at_end = True, + output_dir=args.output_dir, # Model checkpoints directory + 
per_device_train_batch_size=args.per_device_train_batch_size, + per_device_eval_batch_size=args.per_device_eval_batch_size, + predict_with_generate=True, # Required for Seq2Seq + generation_max_length=args.generation_max_tokens, # Default to 20 (depends on the task) + fp16=args.fp16, # Overflows with fp16 + learning_rate=args.lr, # https://huggingface.co/docs/transformers/en/model_doc/t5#training:~:text=Additional%20training%20tips%3A + num_train_epochs=args.num_train_epochs, + warmup_steps=args.warmup_steps, + warmup_ratio=args.warmup_ratio, + weight_decay=args.weight_decay, + gradient_checkpointing=args.gradient_checkpointing, + optim=args.optim, # AdamW or AdaFactor + lr_scheduler_type=args.lr_scheduler_type, + gradient_accumulation_steps=args.gradient_accumulation_steps, + # Logging & evaluation strategies + logging_dir=f"{args.output_dir}/logs", + logging_strategy="steps", + logging_steps=100, + evaluation_strategy="epoch", + save_strategy="epoch", + save_total_limit=1, + load_best_model_at_end=True, # metric_for_best_model = "f1_beta", # Metric used to select the best model. report_to="comet_ml", ) trainer = Seq2SeqTrainer( - model = model, - args = training_args, - data_collator = data_collator, - train_dataset = preprocessed_train_dataset["train"], - eval_dataset = preprocessed_train_dataset["test"], + model=model, + args=training_args, + data_collator=data_collator, + train_dataset=preprocessed_train_dataset["train"], + eval_dataset=preprocessed_train_dataset["test"], ) trainer.train() LOGGER.info("Training finished.") @@ -199,26 +269,43 @@ def train(self, args): # This process is required since the a bug with CometML shuts down connection to the experiment run experiment = comet_ml.get_global_experiment() - LOGGER.info(f"Experiment name after Transformers trainer: {experiment.get_name()}") + LOGGER.info( + f"Experiment name after Transformers trainer: {experiment.get_name()}" + ) experiment = comet_ml.ExistingExperiment(experiment_key=experiment.get_key()) # Run, evaluate and upload benchmark predictions to S3 - evaluator = SpellcheckEvaluator(originals=evaluation_dataset["original"], beta=args.beta) + evaluator = SpellcheckEvaluator( + originals=evaluation_dataset["original"], beta=args.beta + ) LOGGER.info("Start evaluating model on benchmark.") - preds, _, _ = trainer.predict(preprocessed_evaluation_dataset, repetiton_penalty=1.03) - predictions = np.where(preds != -100, preds, tokenizer.pad_token_id) # DataCollator pad tokens with -100 to match labels. See predict() + preds, _, _ = trainer.predict( + preprocessed_evaluation_dataset, repetiton_penalty=1.03 + ) + predictions = np.where( + preds != -100, preds, tokenizer.pad_token_id + ) # DataCollator pad tokens with -100 to match labels. 
See predict() predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True) - predictions = [prediction.strip() for prediction in predictions] # Remove and strip empty strings - metrics = evaluator.evaluate(predictions=predictions, references=evaluation_dataset["reference"]) + predictions = [ + prediction.strip() for prediction in predictions + ] # Remove and strip empty strings + metrics = evaluator.evaluate( + predictions=predictions, references=evaluation_dataset["reference"] + ) LOGGER.info(f"Evaluation metrics: {metrics}") experiment.log_metrics(metrics) # Upload predictions to S3 - prediction_dataset = evaluation_dataset.add_column(name="prediction", column=predictions) - s3_evaluation_path = os.path.join(os.getenv("S3_EVALUATION_URI"), "evaluation-" + SM_JOB_NAME) - LOGGER.info(f"S3 URI where predictions on evaluation are sent to: {s3_evaluation_path}") + prediction_dataset = evaluation_dataset.add_column( + name="prediction", column=predictions + ) + s3_evaluation_path = os.path.join( + os.getenv("S3_EVALUATION_URI"), "evaluation-" + SM_JOB_NAME + ) + LOGGER.info( + f"S3 URI where predictions on evaluation are sent to: {s3_evaluation_path}" + ) prediction_dataset.save_to_disk(s3_evaluation_path) - - + # Experiment tags LOGGER.info(f"Log tags: {EXPERIMENT_TAGS}") experiment.add_tags(EXPERIMENT_TAGS) @@ -227,42 +314,42 @@ def train(self, args): model_uri = os.path.join(S3_MODEL_URI, SM_JOB_NAME, "output/model.tar.gz") LOGGER.info(f"Training job uri: {model_uri}") experiment.log_remote_model( - "flan-t5-small-spellcheck", - model_uri, - sync_mode=False + "flan-t5-small-spellcheck", model_uri, sync_mode=False ) # Log dataset lengths - experiment.log_parameters({ - "training_dataset_length": len(train_dataset), - "evaluation_dataset_length": len(evaluation_dataset), - }) + experiment.log_parameters( + { + "training_dataset_length": len(train_dataset), + "evaluation_dataset_length": len(evaluation_dataset), + } + ) # Log Metaflow run id if metaflow_run_id := os.getenv("METAFLOW_RUN_ID"): experiment.log_parameter("metaflow_run_id", metaflow_run_id) - + LOGGER.info("Training job finished.") def preprocess( - self, - sample: Mapping, - input_name: str, - target_name: str, - tokenizer: AutoTokenizer - ) -> Mapping: + self, + sample: Mapping, + input_name: str, + target_name: str, + tokenizer: AutoTokenizer, + ) -> Mapping: """Preprocess dataset using the `map()` function. Args: sample (Mapping): Batch of the dataset input_name (str): Model training text input feature. - target_name (str): Model training text label feature. + target_name (str): Model training text label feature. tokenizer (AutoTokenizer): Tokenizer Returns: Mapping: Processed batch """ - + # add prefix to the input for t5 inputs = [self.instruction + item for item in sample[input_name]] # tokenize inputs @@ -273,17 +360,15 @@ def preprocess( # padding in the loss. 
if self.padding == "max_length": labels["input_ids"] = [ - [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + [(l if l != tokenizer.pad_token_id else -100) for l in label] + for label in labels["input_ids"] ] model_inputs["labels"] = labels["input_ids"] return model_inputs if __name__ == "__main__": - + args, _ = parse_args() FlanT5Training().train(args) - copy_files( - os.path.join(args.output_dir, "code"), - "requirements.txt" - ) \ No newline at end of file + copy_files(os.path.join(args.output_dir, "code"), "requirements.txt") diff --git a/spellcheck/scripts/training/llm/llm.py b/spellcheck/scripts/training/llm/llm.py index 453c9fec..b8625695 100644 --- a/spellcheck/scripts/training/llm/llm.py +++ b/spellcheck/scripts/training/llm/llm.py @@ -3,7 +3,7 @@ from dotenv import load_dotenv import torch -from datasets import( +from datasets import ( load_dataset, disable_caching, ) @@ -55,32 +55,34 @@ def main(): # SETUP ###################### LOGGER.info("Parse information from CLI using Argparser.") - parser = HfArgumentParser([ - SFTConfig, - # BitsAndBytesConfig, - # LoraConfig, - SFTDataProcessingConfig, - TrainingDataFeatures, - EvaluationDataFeatures, - ModelConfig, - DataConfig, - SavingConfig, - InferenceConfig, - ]) + parser = HfArgumentParser( + [ + SFTConfig, + # BitsAndBytesConfig, + # LoraConfig, + SFTDataProcessingConfig, + TrainingDataFeatures, + EvaluationDataFeatures, + ModelConfig, + DataConfig, + SavingConfig, + InferenceConfig, + ] + ) ( - sft_config, - # quantization_config, - # lora_config, + sft_config, + # quantization_config, + # lora_config, data_processing_config, training_data_features, evaluation_data_features, - model_config, - data_config, - saving_config, + model_config, + data_config, + saving_config, inference_config, ) = parser.parse_args_into_dataclasses() - #NOTE: Bug with LoraConfig and HFArgumentParser (Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because the argument parser only supports one type per argument. Problem encountered in field 'init_lora_weights'.) + # NOTE: Bug with LoraConfig and HFArgumentParser (Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because the argument parser only supports one type per argument. Problem encountered in field 'init_lora_weights'.) # We instantiate LoraConfig "manually" lora_config = LoraConfig( lora_alpha=16, @@ -103,7 +105,7 @@ def main(): # Sagemaker environment variables: https://github.com/aws/sagemaker-training-toolkit/blob/master/ENVIRONMENT_VARIABLES.md OUTPUT_DIR = os.getenv("SM_MODEL_DIR") - #Comet experiment. Will be used in the CometCallback during the training + # Comet experiment. 
Will be used in the CometCallback during the training EXPERIMENT_KEY = os.getenv("COMET_EXPERIMENT_KEY") experiment = comet_ml.ExistingExperiment(previous_experiment=EXPERIMENT_KEY) @@ -112,7 +114,7 @@ def main(): ###################### LOGGER.info("Load datasets.") training_dataset = load_dataset( - path=data_config.training_data, + path=data_config.training_data, split=data_config.train_split, revision=data_config.train_data_revision, ) @@ -125,8 +127,7 @@ def main(): revision=data_config.eval_data_revision, ) datasets = Datasets( - training_dataset=training_dataset, - evaluation_dataset=evaluation_dataset + training_dataset=training_dataset, evaluation_dataset=evaluation_dataset ) LOGGER.info(f"Training dataset: {datasets.training_dataset}") LOGGER.info(f"Evaluation dataset: {datasets.evaluation_dataset}") @@ -141,11 +142,13 @@ def main(): processed_datasets = data_processor.process_datasets( datasets=datasets, training_data_features=training_data_features, - evaluation_data_features=evaluation_data_features + evaluation_data_features=evaluation_data_features, ) LOGGER.info(f"Processed training dataset: {processed_datasets.training_dataset}.") if processed_datasets.evaluation_dataset: - LOGGER.info(f"Processed evaluation dataset: {processed_datasets.evaluation_dataset}.") + LOGGER.info( + f"Processed evaluation dataset: {processed_datasets.evaluation_dataset}." + ) ###################### # MODEL PREPARATION @@ -174,7 +177,7 @@ def main(): # TRAIN ###################### LOGGER.info("Start training.") - # Since we parse config using Argument + # Since we parse config using Argument trainer = SFTTrainer( model=model, args=sft_config, @@ -187,7 +190,7 @@ def main(): # Thus we modified the callback to track experimentation on existing experiment trainer.add_callback(CustomCometCallback) trainer.train() - + ###################### # SAVING ###################### @@ -224,5 +227,6 @@ def main(): LOGGER.info("End of the training job.") + if __name__ == "__main__": main() diff --git a/spellcheck/scripts/training/llm/pretraining_llm.py b/spellcheck/scripts/training/llm/pretraining_llm.py index 3a90c9b5..fbddb79e 100644 --- a/spellcheck/scripts/training/llm/pretraining_llm.py +++ b/spellcheck/scripts/training/llm/pretraining_llm.py @@ -3,7 +3,7 @@ from dotenv import load_dotenv import torch -from datasets import( +from datasets import ( load_dataset, disable_caching, ) @@ -47,24 +47,26 @@ def main(): # SETUP ###################### LOGGER.info("Parse information from CLI using Argparser.") - parser = HfArgumentParser([ - SFTConfig, - # BitsAndBytesConfig, - # LoraConfig, - ModelConfig, - DataConfig, - SavingConfig, - ]) + parser = HfArgumentParser( + [ + SFTConfig, + # BitsAndBytesConfig, + # LoraConfig, + ModelConfig, + DataConfig, + SavingConfig, + ] + ) ( - sft_config, - # quantization_config, - # lora_config, - model_config, - data_config, - saving_config, + sft_config, + # quantization_config, + # lora_config, + model_config, + data_config, + saving_config, ) = parser.parse_args_into_dataclasses() - #NOTE: Bug with LoraConfig and HFArgumentParser (Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because the argument parser only supports one type per argument. Problem encountered in field 'init_lora_weights'.) + # NOTE: Bug with LoraConfig and HFArgumentParser (Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because the argument parser only supports one type per argument. Problem encountered in field 'init_lora_weights'.) 
# We instantiate LoraConfig "manually" lora_config = LoraConfig( lora_alpha=8, @@ -86,21 +88,27 @@ def main(): # Sagemaker environment variables: https://github.com/aws/sagemaker-training-toolkit/blob/master/ENVIRONMENT_VARIABLES.md OUTPUT_DIR = os.getenv("SM_MODEL_DIR") - SM_TRAINING_ENV = json.loads(os.getenv("SM_TRAINING_ENV")) # Need to be deserialized + SM_TRAINING_ENV = json.loads( + os.getenv("SM_TRAINING_ENV") + ) # Need to be deserialized SM_JOB_NAME = SM_TRAINING_ENV["job_name"] # Where the model artifact is store. Can be compressed (model.tar.gz) or decompressed (model/) S3_MODEL_URI = os.path.join(os.getenv("S3_MODEL_URI"), "output/model/") - #Comet experiment + # Comet experiment EXPERIMENT_KEY = os.getenv("COMET_EXPERIMENT_KEY") - experiment = comet_ml.ExistingExperiment(previous_experiment=EXPERIMENT_KEY) if EXPERIMENT_KEY else comet_ml.Experiment() + experiment = ( + comet_ml.ExistingExperiment(previous_experiment=EXPERIMENT_KEY) + if EXPERIMENT_KEY + else comet_ml.Experiment() + ) ###################### # LOAD DATA ###################### LOGGER.info("Load datasets.") training_dataset = load_dataset( - path=data_config.training_data, + path=data_config.training_data, split=data_config.train_split, revision=data_config.train_data_revision, ) @@ -135,7 +143,7 @@ def main(): # TRAIN ###################### LOGGER.info("Start training.") - # Since we parse config using Argument + # Since we parse config using Argument trainer = SFTTrainer( model=model, args=sft_config, @@ -147,7 +155,7 @@ def main(): # Thus we modified the callback to track experimentation on existing experiment trainer.add_callback(CustomCometCallback) trainer.train() - + ###################### # SAVING ###################### @@ -162,17 +170,20 @@ def main(): ###################### # EXPERIMENTATION LOGGING ###################### - LOGGER.info("Start logging additional metrics and parameters to the experiment tracker.") + LOGGER.info( + "Start logging additional metrics and parameters to the experiment tracker." + ) experiment_logger = CometExperimentLogger(experiment=experiment) experiment_logger.log( model_uri=S3_MODEL_URI, model_name="pretrained_model", parameters={ "pretraining_job_name": SM_JOB_NAME, - } + }, ) LOGGER.info("End of the training job.") + if __name__ == "__main__": main() diff --git a/spellcheck/src/spellcheck/__init__.py b/spellcheck/src/spellcheck/__init__.py index 90ed062b..f2d92d76 100644 --- a/spellcheck/src/spellcheck/__init__.py +++ b/spellcheck/src/spellcheck/__init__.py @@ -1 +1 @@ -"""Packaged modules to run Spellcheck dev repo.""" \ No newline at end of file +"""Packaged modules to run Spellcheck dev repo.""" diff --git a/spellcheck/src/spellcheck/argilla/deployment.py b/spellcheck/src/spellcheck/argilla/deployment.py index cbe7d6c0..994d5acb 100644 --- a/spellcheck/src/spellcheck/argilla/deployment.py +++ b/spellcheck/src/spellcheck/argilla/deployment.py @@ -20,14 +20,14 @@ class ArgillaModule(ABC): @abstractmethod def deploy(self, dataset_name: str, workspace_name: str = "spellcheck") -> None: - """Deploy Dataset into Argilla. + """Deploy Dataset into Argilla. Args: dataset_name (str): Argilla dataset name workspace_name (str, optional): Argilla workspace name. Defaults to "spellcheck". """ raise NotImplementedError - + @abstractmethod def _prepare_dataset(self) -> rg.FeedbackDataset: """Prepare Argilla Dataset architecture for annotation. 
@@ -36,13 +36,13 @@ def _prepare_dataset(self) -> rg.FeedbackDataset: rg.FeedbackDataset """ raise NotImplementedError - + @abstractmethod def _prepare_records(self) -> Iterable[rg.FeedbackRecord]: - """Records are prepared in respect of the preconfigured fields. + """Records are prepared in respect of the preconfigured fields. Returns: - Iterable[rg.FeedbackRecord]: Batch of records. + Iterable[rg.FeedbackRecord]: Batch of records. """ raise NotImplementedError @@ -51,13 +51,13 @@ def _prepare_records(self) -> Iterable[rg.FeedbackRecord]: def from_jsonl(cls, path: Path) -> None: """Load the data from a JSONL file.""" raise NotImplementedError - + @classmethod @abstractmethod def from_parquet(cls, path: Path) -> None: """Load the data from a parquet file.""" raise NotImplementedError - + @classmethod @abstractmethod def from_s3(cls, uri) -> None: @@ -67,84 +67,80 @@ def from_s3(cls, uri) -> None: class BenchmarkEvaluationArgilla(ArgillaModule): """Argilla module for model human evaluation step. - + Args: originals (Iterable[str]): Batch of original lists of ingredients references (Iterable[str]): Batch of references as annotated in the benchmark predictions (Iterable[str]): Batch of model predictions metadata (Iterable[Dict]): Batch of metadata associated with each list of ingredients """ + def __init__( self, originals: Iterable[str], references: Iterable[str], predictions: Iterable[str], - metadata: Iterable[Dict] + metadata: Iterable[Dict], ): self.originals = originals self.references = references self.predictions = predictions self.metadata = metadata - def deploy( - self, - dataset_name: str, - workspace_name: str = "spellcheck" - ) -> None: + def deploy(self, dataset_name: str, workspace_name: str = "spellcheck") -> None: rg.init( - api_url=os.getenv("ARGILLA_API_URL"), - api_key=os.getenv("ARGILLA_API_KEY") + api_url=os.getenv("ARGILLA_API_URL"), api_key=os.getenv("ARGILLA_API_KEY") ) dataset = self._prepare_dataset() records = self._prepare_records() dataset.add_records(records=records) dataset.push_to_argilla(name=dataset_name, workspace=workspace_name) - def _prepare_dataset(self) -> rg.FeedbackDataset: dataset = rg.FeedbackDataset( fields=[ rg.TextField(name="original", title="Original", use_markdown=True), rg.TextField(name="reference", title="Reference", use_markdown=True), - rg.TextField(name="prediction", title="Prediction", use_markdown=True) + rg.TextField(name="prediction", title="Prediction", use_markdown=True), ], questions=[ rg.LabelQuestion( name="is_good", title="Is the correction correct?", - labels=["Good","Bad"], - required=True + labels=["Good", "Bad"], + required=True, ), rg.TextQuestion( - name="notes", - title="Explain your decision: ", - required=False - ) + name="notes", title="Explain your decision: ", required=False + ), ], metadata_properties=[ rg.TermsMetadataProperty(name="lang", title="Language"), ], ) return dataset - + def _prepare_records(self) -> Iterable[rg.FeedbackRecord]: records = [] for original, reference, prediction, metadata in zip( - self.originals, self.highlighted_references, self.highlighted_predictions, self.metadata + self.originals, + self.highlighted_references, + self.highlighted_predictions, + self.metadata, ): record = rg.FeedbackRecord( fields={ "original": original, "reference": reference, - "prediction": prediction + "prediction": prediction, }, metadata={ "lang": metadata.get("lang"), - } + }, ) records.append(record) return records - + @classmethod def from_jsonl(cls, path: Path): elements = load_jsonl(path) @@ -152,153 
+148,151 @@ def from_jsonl(cls, path: Path): [element["original"] for element in elements], [element["reference"] for element in elements], [element["prediction"] for element in elements], - [element["metadata"] for element in elements] + [element["metadata"] for element in elements], ) - + @classmethod def from_parquet(path: Path) -> None: raise NotImplementedError - + @classmethod def from_s3(cls, uri: str) -> None: if os.path.splitext(uri)[-1]: - raise ValueError("The S3 uri should be directed to a Hugging Face Dataset folder.") + raise ValueError( + "The S3 uri should be directed to a Hugging Face Dataset folder." + ) dataset = datasets.load_from_disk(uri) return cls( dataset["original"], dataset["reference"], dataset["prediction"], - [{"lang": lang} for lang in dataset["lang"]] + [{"lang": lang} for lang in dataset["lang"]], ) - + @property def highlighted_references(self) -> Iterable[str]: - """Highlight references. - """ - return [show_diff(original, reference, color="yellow") for original, reference in zip(self.originals, self.references)] + """Highlight references.""" + return [ + show_diff(original, reference, color="yellow") + for original, reference in zip(self.originals, self.references) + ] @property def highlighted_predictions(self) -> Iterable[str]: - """Highlight predictions. - """ - return [show_diff(reference, prediction, color="red") for reference, prediction in zip(self.references, self.predictions)] - + """Highlight predictions.""" + return [ + show_diff(reference, prediction, color="red") + for reference, prediction in zip(self.references, self.predictions) + ] + @classmethod def from_dataset( - cls, - path: str, + cls, + path: str, original_feature: str = "original", reference_feature: str = "reference", - prediction_feature: str = "prediction" + prediction_feature: str = "prediction", ) -> None: dataset = datasets.load_from_disk(path) return cls( originals=dataset[original_feature], references=dataset[reference_feature], predictions=dataset[prediction_feature], - metadata=[{"lang": lang} for lang in dataset["lang"]] + metadata=[{"lang": lang} for lang in dataset["lang"]], ) - + class IngredientsCompleteEvaluationArgilla(ArgillaModule): """Prepare Ingredients-Complete dataset for False Positives verification. 
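As a usage note for the `BenchmarkEvaluationArgilla` hunks above, here is a minimal sketch of driving the module end to end. It assumes the import path `spellcheck.argilla.deployment` (the file's location in this patch), that `ARGILLA_API_URL`/`ARGILLA_API_KEY` are set in the environment, and an illustrative JSONL path and dataset name; none of this is part of the committed code.

```
# Minimal sketch; paths, dataset name and import path are assumptions.
from pathlib import Path

from spellcheck.argilla.deployment import BenchmarkEvaluationArgilla

# The JSONL is expected to hold one object per line with the keys
# "original", "reference", "prediction" and "metadata" (see from_jsonl above).
module = BenchmarkEvaluationArgilla.from_jsonl(Path("data/evaluation/predictions.jsonl"))

# Pushes the prepared FeedbackDataset and its records to Argilla,
# reading ARGILLA_API_URL / ARGILLA_API_KEY from the environment.
module.deploy(dataset_name="benchmark-evaluation-v1", workspace_name="spellcheck")
```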
- + Args: originals (Iterable[str]): Batch of original lists of ingredients predictions (Iterable[str]): Batch of model predictions metadata (Iterable[Dict]): Batch of metadata associated with each list of ingredients """ - + def __init__( - self, - originals: Iterable[str], + self, + originals: Iterable[str], predictions: Iterable[str], - metadata: Iterable[Dict] + metadata: Iterable[Dict], ): self.originals = originals self.predictions = predictions self.metadata = metadata - - def deploy( - self, - dataset_name: str, - workspace_name: str = "spellcheck" - ) -> None: + + def deploy(self, dataset_name: str, workspace_name: str = "spellcheck") -> None: rg.init( - api_url=os.getenv("ARGILLA_API_URL"), - api_key=os.getenv("ARGILLA_API_KEY") + api_url=os.getenv("ARGILLA_API_URL"), api_key=os.getenv("ARGILLA_API_KEY") ) dataset = self._prepare_dataset() records = self._prepare_records() dataset.add_records(records=records) dataset.push_to_argilla(name=dataset_name, workspace=workspace_name) - + def _prepare_dataset(self) -> rg.FeedbackDataset: dataset = rg.FeedbackDataset( fields=[ rg.TextField(name="original", title="Original", use_markdown=True), - rg.TextField(name="prediction", title="Prediction", use_markdown=True) + rg.TextField(name="prediction", title="Prediction", use_markdown=True), ], questions=[ rg.LabelQuestion( name="is_good", title="Is the correction correct?", - labels=["Good","Bad"], - required=True + labels=["Good", "Bad"], + required=True, ), rg.TextQuestion( - name="notes", - title="Explain your decision: ", - required=False - ) + name="notes", title="Explain your decision: ", required=False + ), ], metadata_properties=[ rg.TermsMetadataProperty(name="lang", title="Language"), - rg.TermsMetadataProperty(name="code", title="Code") + rg.TermsMetadataProperty(name="code", title="Code"), ], ) return dataset - + def _prepare_records(self) -> Iterable[rg.FeedbackRecord]: records = [] for original, prediction, metadata in zip( self.originals, self.highlighted_predictions, self.metadata ): record = rg.FeedbackRecord( - fields={ - "original": original, - "prediction": prediction - }, + fields={"original": original, "prediction": prediction}, metadata={ "lang": metadata.get("lang"), - "code": str(metadata.get("code")) # String required instead of int - } + "code": str(metadata.get("code")), # String required instead of int + }, ) records.append(record) return records - + @classmethod def from_jsonl(cls, path: Path): elements = load_jsonl(path) return cls( [element["original"] for element in elements], [element["prediction"] for element in elements], - [element["metadata"] for element in elements] + [element["metadata"] for element in elements], ) - + @classmethod def from_parquet(path: Path) -> None: raise NotImplementedError - + @classmethod def from_s3(path: Path) -> None: raise NotImplementedError @property def highlighted_predictions(self): - """Highlight predictions. 
- """ - return [show_diff(original, prediction, color="red") for original, prediction in zip(self.originals, self.predictions)] + """Highlight predictions.""" + return [ + show_diff(original, prediction, color="red") + for original, prediction in zip(self.originals, self.predictions) + ] class BenchmarkArgilla(ArgillaModule): @@ -320,14 +314,9 @@ def __init__( self.references = references self.metadata = metadata - def deploy( - self, - dataset_name: str, - workspace_name: str = "spellcheck" - ): + def deploy(self, dataset_name: str, workspace_name: str = "spellcheck"): rg.init( - api_url=os.getenv("ARGILLA_API_URL"), - api_key=os.getenv("ARGILLA_API_KEY") + api_url=os.getenv("ARGILLA_API_URL"), api_key=os.getenv("ARGILLA_API_KEY") ) dataset = self._prepare_dataset() records = self._prepare_records() @@ -341,35 +330,40 @@ def _prepare_dataset(self) -> rg.FeedbackDataset: rg.TextField(name="original", title="Original"), ], questions=[ - rg.TextQuestion(name="reference", title="Correct the prediction.", use_markdown=True), + rg.TextQuestion( + name="reference", title="Correct the prediction.", use_markdown=True + ), ], metadata_properties=[ rg.TermsMetadataProperty(name="lang", title="Language"), ], ) return dataset - + def _prepare_records(self): records = [] - for original, reference, metadata in zip(self.originals, self.references, self.metadata): + for original, reference, metadata in zip( + self.originals, self.references, self.metadata + ): record = rg.FeedbackRecord( fields={ "original": original, - "code": metadata["code"] if metadata["code"] else "Code not available." + "code": ( + metadata["code"] if metadata["code"] else "Code not available." + ), }, suggestions=[ rg.SuggestionSchema( - question_name="reference", - value=show_diff(original, reference) + question_name="reference", value=show_diff(original, reference) ) ], metadata={ "lang": metadata.get("lang"), - } + }, ) records.append(record) return records - + @classmethod def from_parquet(cls, path: Path): """Load the data from a parquet file.""" @@ -378,13 +372,13 @@ def from_parquet(cls, path: Path): return cls( originals=df["original"].tolist(), references=df["reference"].tolist(), - metadata=metadata + metadata=metadata, ) - + @classmethod def from_jsonl(path: Path) -> None: raise NotImplementedError - + @classmethod def from_s3(path: Path) -> None: raise NotImplementedError @@ -397,7 +391,7 @@ def __init__( self, originals: Iterable[str], references: Iterable[str], - metadata: Iterable[Dict] + metadata: Iterable[Dict], ) -> None: self.originals = originals self.references = references @@ -405,8 +399,7 @@ def __init__( def deploy(self, dataset_name: str, workspace_name: str = "spellcheck"): rg.init( - api_url=os.getenv("ARGILLA_API_URL"), - api_key=os.getenv("ARGILLA_API_KEY") + api_url=os.getenv("ARGILLA_API_URL"), api_key=os.getenv("ARGILLA_API_KEY") ) dataset = self._prepare_dataset() records = self._prepare_records() @@ -419,26 +412,26 @@ def _prepare_dataset(self) -> rg.FeedbackDataset: rg.TextField(name="original", title="Original", use_markdown=True), ], questions=[ - rg.TextQuestion(name="reference", title="Correct the prediction.", use_markdown=True), + rg.TextQuestion( + name="reference", title="Correct the prediction.", use_markdown=True + ), rg.LabelQuestion( name="is_truncated", title="Is the list of ingredients truncated?", - labels=["YES","NO"], - required=False - ) + labels=["YES", "NO"], + required=False, + ), ], metadata_properties=[ rg.TermsMetadataProperty(name="lang", title="Language") ], ) return 
dataset - + def _prepare_records(self) -> Iterable[rg.FeedbackRecord]: records = [] for original, highlighted_reference, metadata in zip( - self.originals, - self.highlighted_references, - self.metadata + self.originals, self.highlighted_references, self.metadata ): record = rg.FeedbackRecord( fields={ @@ -446,20 +439,22 @@ def _prepare_records(self) -> Iterable[rg.FeedbackRecord]: }, suggestions=[ rg.SuggestionSchema( - question_name="reference", - value=highlighted_reference + question_name="reference", value=highlighted_reference ) ], metadata={ "lang": metadata.get("lang"), - } + }, ) records.append(record) return records - + @property def highlighted_references(self): - return [show_diff(original, reference) for original, reference in tqdm(zip(self.originals, self.references))] + return [ + show_diff(original, reference) + for original, reference in tqdm(zip(self.originals, self.references)) + ] @classmethod def from_jsonl(cls, path: Path): @@ -467,21 +462,21 @@ def from_jsonl(cls, path: Path): return cls( [element["original"] for element in elements], [element["reference"] for element in elements], - [element["metadata"] for element in elements] + [element["metadata"] for element in elements], ) - + @classmethod def from_parquet(path: Path) -> None: raise NotImplementedError - + @classmethod def from_s3(path: Path) -> None: raise NotImplementedError - + @classmethod def from_dataset( - cls, - hf_repo: str, + cls, + hf_repo: str, split: str = "train", original_feature: str = "original", reference_feature: str = "reference", @@ -490,5 +485,5 @@ def from_dataset( return cls( originals=dataset[original_feature], references=dataset[reference_feature], - metadata=[{"lang": lang} for lang in dataset["lang"]] + metadata=[{"lang": lang} for lang in dataset["lang"]], ) diff --git a/spellcheck/src/spellcheck/argilla/extraction.py b/spellcheck/src/spellcheck/argilla/extraction.py index fe252671..612c95e6 100644 --- a/spellcheck/src/spellcheck/argilla/extraction.py +++ b/spellcheck/src/spellcheck/argilla/extraction.py @@ -14,7 +14,7 @@ class ArgillaExtraction(ABC, BaseModel): - + dataset_name: str extracted_status: List[Literal["submitted", "pending", "draft", "discarded"]] workspace_name: str = "spellcheck" @@ -23,46 +23,47 @@ class ArgillaExtraction(ABC, BaseModel): def extract_dataset(self) -> Dataset: dataset = FeedbackDataset.from_argilla( - name=self.dataset_name, - workspace=self.workspace_name + name=self.dataset_name, workspace=self.workspace_name ).format_as("datasets") LOGGER.info(f"Dataset: {dataset}") processed_dataset = self._postprocess_dataset(dataset) LOGGER.info(f"Post-processed dataset: {processed_dataset}") return processed_dataset - + def _postprocess_dataset(self, dataset: Dataset) -> Dataset: - return ( - dataset - .filter(self._filter_fn) - .map(self._map_fn, batched=False, remove_columns=dataset.column_names) + return dataset.filter(self._filter_fn).map( + self._map_fn, batched=False, remove_columns=dataset.column_names ) - + def _remove_highlight_markdown(self, text: str) -> str: """Highlights were added during Argilla deployment to show corrections. They are removed during the extraction. 
""" - text = re.sub("]*)?>" + self.deleted_element + "<\/mark>", "", text) # # - # if an element was deleted - text = re.sub("<\/?mark(?:\s\w+[^>]*)?>", "", text) # - - + text = re.sub( + "]*)?>" + self.deleted_element + "<\/mark>", "", text + ) # # - # if an element was deleted + text = re.sub( + "<\/?mark(?:\s\w+[^>]*)?>", "", text + ) # - - return text - + @abstractmethod def _filter_fn(self, element: Mapping) -> Mapping: raise NotImplementedError - + @abstractmethod def _map_fn(self, element: Mapping) -> Mapping: raise NotImplementedError - + class SpellcheckExtraction(ArgillaExtraction): - """Benchmark and Training Dataset extraction. + """Benchmark and Training Dataset extraction. Here are some examples of the extracted dataset elements during the extraction process. Example 1: ``` 'url': 'https://world.openfoodfacts.org/product/5942262001416' - 'original': 'water:snow' + 'original': 'water:snow' 'reference': [{'user_id': 'dfb71753-1187-45e1-8006-629bef2b49e0', 'value': 'water:snow', 'status': 'discarded'}] 'reference-suggestion': 'water:snow' 'reference-suggestion-metadata': {'type': None, 'score': None, 'agent': None} @@ -79,8 +80,8 @@ class SpellcheckExtraction(ArgillaExtraction): 'original': 'Ananas, Ananassaft, Säuerungs - mittel: Citronensäure' 'reference': [ { - 'user_id': 'dfb71753-1187-45e1-8006-629bef2b49e0', - 'value': 'Ananas, Ananassaft, Säuerungsmittel: Citronensäure', + 'user_id': 'dfb71753-1187-45e1-8006-629bef2b49e0', + 'value': 'Ananas, Ananassaft, Säuerungsmittel: Citronensäure', 'status': 'submitted' } ] @@ -110,7 +111,11 @@ def _map_fn(self, element: Mapping[str, Any]) -> Mapping[str, Any]: Mapping[str, Any]: Processed element. """ # If status pending, we take the suggestion from the LLM - reference = element["reference"][0]["value"] if element["reference"] else element["reference-suggestion"] + reference = ( + element["reference"][0]["value"] + if element["reference"] + else element["reference-suggestion"] + ) postprocessed_reference = self._remove_highlight_markdown(reference) # Metadata is JSON encoded lang = json.loads(element["metadata"]).get("lang") @@ -119,9 +124,14 @@ def _map_fn(self, element: Mapping[str, Any]) -> Mapping[str, Any]: "reference": postprocessed_reference, "lang": lang, "code": element.get("code"), - "is_truncated": 0 if not element.get("is_truncated") or element["is_truncated"][0]["value"] == "NO" else 1 + "is_truncated": ( + 0 + if not element.get("is_truncated") + or element["is_truncated"][0]["value"] == "NO" + else 1 + ), } - + def _filter_fn(self, element: Mapping[str, Any]) -> bool: """Filter function applied to Dataset with Dataset.filter() @@ -132,14 +142,14 @@ def _filter_fn(self, element: Mapping[str, Any]) -> bool: bool: whether to keep (True) or drop (False) the element """ reference = element.get("reference") - # Status == Pending means no annotation were performed by annotator, but the LLM suggestion remains. + # Status == Pending means no annotation were performed by annotator, but the LLM suggestion remains. if not reference and "pending" in self.extracted_status: return True # Since it can be possible there are several annotators, we only take the last annotation if reference and reference[0]["status"] in self.extracted_status: return True return False - + class SpellcheckDPOExtraction(ArgillaExtraction): """Extract chosen and rejected correction from Argilla. This dataset is used to train a DPO (Direct Preference Optimization) model. 
@@ -147,13 +157,14 @@ class SpellcheckDPOExtraction(ArgillaExtraction): * 'Chosen': annotator modification. * 'Rejected': LLM original suggestion """ - + def __init__(self, **kwargs) -> None: - """DPO Extraction only works for "submitted" are in extracted_status. - """ + """DPO Extraction only works for "submitted" are in extracted_status.""" super().__init__(**kwargs) if "submitted" not in self.extracted_status: - raise ValueError(f"'Submitted' not in extracted_status. Current status: {self.extracted_status}") + raise ValueError( + f"'Submitted' not in extracted_status. Current status: {self.extracted_status}" + ) def _map_fn(self, element: Mapping[str, Any]) -> Mapping[str, Any]: """_summary_ @@ -177,7 +188,7 @@ def _map_fn(self, element: Mapping[str, Any]) -> Mapping[str, Any]: "rejected": postprocessed_rejected, "lang": lang, } - + def _filter_fn(self, element: Mapping[str, Any]) -> bool: """Filter function that considers only examples with submitted annotations. @@ -188,8 +199,8 @@ def _filter_fn(self, element: Mapping[str, Any]) -> bool: bool: Whether the row is kept. """ reference = element.get("reference") - if not reference: + if not reference: return False if reference[0]["status"] in self.extracted_status: return True - return False \ No newline at end of file + return False diff --git a/spellcheck/src/spellcheck/config.py b/spellcheck/src/spellcheck/config.py index 80aeb584..cdeed84b 100644 --- a/spellcheck/src/spellcheck/config.py +++ b/spellcheck/src/spellcheck/config.py @@ -3,9 +3,9 @@ @dataclass class ArgillaConfig: - deleted_element = "#" # Element to add in show_diff() to indicate a deleted element + deleted_element = "#" # Element to add in show_diff() to indicate a deleted element html_colors = { "yellow": "#FCF910", "orange": "#EA990C", "red": "#E94646", - } \ No newline at end of file + } diff --git a/spellcheck/src/spellcheck/evaluation/evaluation.py b/spellcheck/src/spellcheck/evaluation/evaluation.py index 7bb17ee0..65b76f38 100644 --- a/spellcheck/src/spellcheck/evaluation/evaluation.py +++ b/spellcheck/src/spellcheck/evaluation/evaluation.py @@ -1,4 +1,5 @@ """Evaluation module.""" + from pathlib import Path from typing import Iterable, Mapping, Tuple import time @@ -17,12 +18,11 @@ def import_benchmark( - path: Path, - start_from: int = 0 + path: Path, start_from: int = 0 ) -> Tuple[Iterable[str], Iterable[str], Iterable[Mapping]]: """Load benchmark. - It is possible a previous evaluation didn't go through the entire benchmark for many reasons. - In this case, the evaluation is restarted from a specific index instead of starting from the beginning. + It is possible a previous evaluation didn't go through the entire benchmark for many reasons. + In this case, the evaluation is restarted from a specific index instead of starting from the beginning. Args: path (Path): Benchmark as a parquet file @@ -31,7 +31,9 @@ def import_benchmark( Tuple[Iterable[str], Iterable[str], Iterable[Mapping]]: Text and Metadata from the benchmark """ if path.suffix != ".parquet": - raise ValueError(f"Wrong file format. Parquet required. Instead {path.suffix} provided") + raise ValueError( + f"Wrong file format. Parquet required. 
Instead {path.suffix} provided" + ) # Benchmark df = pd.read_parquet(path) # In case the begininning of the benchmark was already processed @@ -44,8 +46,7 @@ def import_benchmark( def import_ingredients_complete( - path: Path, - start_from: int = 0 + path: Path, start_from: int = 0 ) -> Tuple[Iterable[str], Iterable[str], Iterable[Mapping]]: """Load Ingredients complete dataset to evaluate Spellcheck on False Positives. @@ -58,14 +59,20 @@ def import_ingredients_complete( In this case, References are considered identical to Originals """ if path.suffix != ".parquet": - raise ValueError(f"Wrong file format. Parquet required. Instead {path.suffix} provided") + raise ValueError( + f"Wrong file format. Parquet required. Instead {path.suffix} provided" + ) df = pd.read_parquet(path) df = df.iloc[start_from:] LOGGER.info(f"Data features: {df.columns}") LOGGER.info(f"Data length: {len(df)}") originals = df["ingredients_text"].to_list() - references = originals.copy() # Reference = Original, which means the original is considered as perfect (no error to correct) - metadata = [{"lang": lang, "code": code} for lang, code in zip(df["lang"], df["code"])] + references = ( + originals.copy() + ) # Reference = Original, which means the original is considered as perfect (no error to correct) + metadata = [ + {"lang": lang, "code": code} for lang, code in zip(df["lang"], df["code"]) + ] return originals, references, metadata @@ -79,6 +86,7 @@ class Evaluate: benchmark_version (str): Version of the benchmark. predictions_path (Path): Path where all predictions against the benchmark are stored for further analysis. """ + def __init__( self, model_name: str, @@ -92,17 +100,17 @@ def __init__( self.benchmark_version = benchmark_version self.prompt_version = prompt_version self.predictions_path = predictions_path - + def run_evaluation( self, originals: Iterable[str], references: Iterable[str], spellcheck: Spellcheck, metadata: Iterable[Mapping], - wait: int = None + wait: int = None, ) -> None: """Run the Spellcheck module against the benchmark and store the predictions in predictions_path as a JSONL. - Addding predictions in a JSONL file prevents API request failures to erase the processed data. + Addding predictions in a JSONL file prevents API request failures to erase the processed data. Args: originals (Iterable[str]): ists of ingredients as seen on the website. 
@@ -115,10 +123,10 @@ def run_evaluation( LOGGER.info(f"Appending {str(self.predictions_path)} file.") with open(self.predictions_path, "a") as file: for original, reference, md in tqdm( - zip(originals, references, metadata), - desc="Evaluation against benchmark", - total=len(originals) - ): + zip(originals, references, metadata), + desc="Evaluation against benchmark", + total=len(originals), + ): timestamp = time.time() prediction = spellcheck.correct(original) md["latency"] = time.time() - timestamp @@ -126,25 +134,26 @@ def run_evaluation( "original": original, "reference": reference, "prediction": prediction, - "metadata": md + "metadata": md, } - json.dump(output, file, ensure_ascii=False) # Ensure ascii for accents + json.dump(output, file, ensure_ascii=False) # Ensure ascii for accents file.write("\n") - file.flush() # Immediatly write the line into the file - # In case Requests Per Minute are limited + file.flush() # Immediatly write the line into the file + # In case Requests Per Minute are limited if wait: time.sleep(wait) def compute_metrics(self) -> None: - """From the predictions JSONL containing the Spellcheck predictions, compute the metrics using the evaluation module. - """ + """From the predictions JSONL containing the Spellcheck predictions, compute the metrics using the evaluation module.""" with open(self.predictions_path, "r") as file: lines = file.readlines() elements = [json.loads(line) for line in lines] originals = [element["original"] for element in elements] references = [element["reference"] for element in elements] predictions = [element["prediction"] for element in elements] - evaluator = SpellcheckEvaluator(originals=originals) #TODO Remove the module call from the function + evaluator = SpellcheckEvaluator( + originals=originals + ) # TODO Remove the module call from the function metrics = evaluator.evaluate(predictions, references) metrics_output = { "metrics": metrics, @@ -152,8 +161,8 @@ def compute_metrics(self) -> None: "date": datetime.now().strftime("%d/%m/%Y %H:%M:%S"), "benchmark_version": self.benchmark_version, "prompt_version": self.prompt_version, - "benchmark_size": len(predictions) + "benchmark_size": len(predictions), } with open(self.metrics_path, "a") as file: json.dump(metrics_output, file, indent=4) - file.write("\n") + file.write("\n") diff --git a/spellcheck/src/spellcheck/evaluation/evaluator.py b/spellcheck/src/spellcheck/evaluation/evaluator.py index 3cb4b32a..abb27bc5 100644 --- a/spellcheck/src/spellcheck/evaluation/evaluator.py +++ b/spellcheck/src/spellcheck/evaluation/evaluator.py @@ -16,10 +16,7 @@ class Evaluator(ABC): @abstractmethod def evaluate( - self, - predictions: Iterable[str], - references: Iterable[str], - **kwargs + self, predictions: Iterable[str], references: Iterable[str], **kwargs ) -> Mapping: """Abstract method to calculate metrics based on the predictions and the evaluation method. @@ -48,8 +45,8 @@ class SpellcheckEvaluator(Evaluator): The process is divided into 4 steps: * Texts (Original-Reference-Prediction) are tokenized using a Byte Pair Encoding (BPE) tokenizer from the `tiktoken` library. - * Encoded originals and references are aligned using Sequence Alignment technique to locate which tokens were transformed, added, or deleted. - The same process is applied to encoded originals and predictions. + * Encoded originals and references are aligned using Sequence Alignment technique to locate which tokens were transformed, added, or deleted. 
+ The same process is applied to encoded originals and predictions. * Pairs of tokens (Original-Reference; Original-Prediction) are aligned to consider gaps in case Reference and/or Prediction have different length. * Precision, Recall, and F1-score are calculated based on pairs of tokens for either the reference and the prediction. @@ -66,10 +63,10 @@ class SpellcheckEvaluator(Evaluator): originals (List[str]): Batch of original ingredient lists to correct references (List[str]): Batch of expected ingredients lists after correction encoding_name (str, optional): BPE tokenizer from the tiktoken library. Defaults to "cl100k_base". - beta (float, optional): Coefficient for F1_beta metric. A coefficient of less than 1.0 gives more weight to the Recall, + beta (float, optional): Coefficient for F1_beta metric. A coefficient of less than 1.0 gives more weight to the Recall, whereas a coefficient greater than 1.0 gives more weight to the Precision. drop_rate (foat, optional): Some predictions are (almost) empty, which means it would be irrelevant to compare them (alignment issue) - Therefore we drop to not bias the metrics. An additional metric is added to count them. + Therefore we drop to not bias the metrics. An additional metric is added to count them. """ def __init__( @@ -77,7 +74,7 @@ def __init__( originals: Iterable[str], encoding_name: str = "cl100k_base", beta: float = 1.0, - drop_rate: float = 0.4 + drop_rate: float = 0.4, ) -> None: self.originals = originals self.encoder = tiktoken.get_encoding(encoding_name=encoding_name) @@ -87,7 +84,7 @@ def __init__( def evaluate(self, predictions: List[str], references: List[str]) -> Mapping: """Evaluate the performance of Spellcheck on correcting ingredient lists for ingredients extraction. - Metrics: + Metrics: * Precision * Recall * F1 @@ -123,9 +120,9 @@ def evaluate(self, predictions: List[str], references: List[str]) -> Mapping: # Batch for original, reference, prediction in tqdm( - zip(normalized_originals, normalized_references, normalized_predictions), + zip(normalized_originals, normalized_references, normalized_predictions), total=len(predictions), - desc="Evaluation" + desc="Evaluation", ): # Convert into tokens. @@ -142,7 +139,7 @@ def evaluate(self, predictions: List[str], references: List[str]) -> Mapping: ref_pairs = self.sequence_alignment(original_tokens, reference_tokens) pred_pairs = self.sequence_alignment(original_tokens, prediction_tokens) - # Align ref-pairs and pred-pairs + # Align ref-pairs and pred-pairs aligned_ref_pairs, aligned_pred_pairs = self.align_pairs( ref_pairs, pred_pairs ) @@ -150,19 +147,24 @@ def evaluate(self, predictions: List[str], references: List[str]) -> Mapping: # Convert pairs into sparse matrices for metrics calculation sparse_ref_pairs = self.convert_pairs_into_sparse(aligned_ref_pairs) sparse_pred_pairs = self.convert_pairs_into_sparse(aligned_pred_pairs) - assert len(sparse_ref_pairs) == len(sparse_pred_pairs), "Ref and pred pairs don't have the same length!" + assert len(sparse_ref_pairs) == len( + sparse_pred_pairs + ), "Ref and pred pairs don't have the same length!" 
inverse_sparse_ref_pairs = [1 if i == 0 else 0 for i in sparse_ref_pairs] inverse_sparse_pred_pairs = [1 if i == 0 else 0 for i in sparse_pred_pairs] seq_true_positives = np.multiply(sparse_ref_pairs, sparse_pred_pairs) - seq_false_positives = np.multiply(inverse_sparse_ref_pairs, sparse_pred_pairs) - seq_false_negatives = np.multiply(sparse_ref_pairs, inverse_sparse_pred_pairs) + seq_false_positives = np.multiply( + inverse_sparse_ref_pairs, sparse_pred_pairs + ) + seq_false_negatives = np.multiply( + sparse_ref_pairs, inverse_sparse_pred_pairs + ) # Also check if model token predictions are correct seq_correction_true_positives = self.get_correction_true_positives( - ref_pairs=aligned_ref_pairs, - pred_pairs=aligned_pred_pairs + ref_pairs=aligned_ref_pairs, pred_pairs=aligned_pred_pairs ) true_positives.extend(seq_true_positives) @@ -177,14 +179,41 @@ def evaluate(self, predictions: List[str], references: List[str]) -> Mapping: correction_true_positive = np.sum(correction_true_positives) # Metrics calculation - precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) != 0 else 0 - recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) != 0 else 0 - f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else 0 - f1_beta = (1 + self.beta**2) * precision * recall / (self.beta**2 * precision + recall) if (precision + recall) != 0 else 0 - correction_precision = correction_true_positive / (true_positive + false_positive) if true_positive != 0 else 0 - correction_recall = correction_true_positive / (true_positive + false_negative) if correction_true_positive !=0 else 0 - - # Mean results for the entire batch + precision = ( + true_positive / (true_positive + false_positive) + if (true_positive + false_positive) != 0 + else 0 + ) + recall = ( + true_positive / (true_positive + false_negative) + if (true_positive + false_negative) != 0 + else 0 + ) + f1 = ( + 2 * (precision * recall) / (precision + recall) + if (precision + recall) != 0 + else 0 + ) + f1_beta = ( + (1 + self.beta**2) + * precision + * recall + / (self.beta**2 * precision + recall) + if (precision + recall) != 0 + else 0 + ) + correction_precision = ( + correction_true_positive / (true_positive + false_positive) + if true_positive != 0 + else 0 + ) + correction_recall = ( + correction_true_positive / (true_positive + false_negative) + if correction_true_positive != 0 + else 0 + ) + + # Mean results for the entire batch results = { "correction_precision": correction_precision, "correction_recall": correction_recall, @@ -193,7 +222,7 @@ def evaluate(self, predictions: List[str], references: List[str]) -> Mapping: "f1": f1, "f1_beta": f1_beta, "beta": self.beta, - "drop_count": drop_count + "drop_count": drop_count, } LOGGER.info(f"Evaluation metrics: {results}") return results @@ -222,14 +251,14 @@ def sequence_alignment( Returns: List[Tuple]: List of token pairs. 
- Example 1: + Example 1: ``` tokens_1 = [791, 8415, 4502, 389, 279, 282, 1425, 13] tokens_2 = [791, 8415, 374, 389, 279, 38681, 13] alignment = [(791, 791), (8415, 8415), (4502, 374), (389, 389), (279, 279), (282, None), (1425, 38681), (13, 13)] ``` - + Example 2: ``` tokens_1 = [54, 51, 23, 165, 415, 61, 561] @@ -239,7 +268,9 @@ def sequence_alignment( ``` """ # Initialize matrix - matrix = [[i * gap_penalty] + [0] * (len(tokens_2)) for i in range(len(tokens_1) + 1)] + matrix = [ + [i * gap_penalty] + [0] * (len(tokens_2)) for i in range(len(tokens_1) + 1) + ] matrix[0] = [(j * gap_penalty) for j in range(len(tokens_2) + 1)] # Fill in the matrix for i in range(1, len(tokens_1) + 1): @@ -268,7 +299,6 @@ def sequence_alignment( elif matrix[i][j] == matrix[i][j - 1] + gap_penalty: alignment.append((None, tokens_2[j - 1])) # Mark insertion j -= 1 - # Handle remaining elements if any while i > 0: @@ -282,13 +312,13 @@ def sequence_alignment( @staticmethod def convert_pairs_into_sparse(pairs: List[Tuple]) -> List[int]: - """Convert alignement pairs/tuples into a sparse vector. - If there is a mismatch between tokens from the same pair, it is considered as a modification (=1). - + """Convert alignement pairs/tuples into a sparse vector. + If there is a mismatch between tokens from the same pair, it is considered as a modification (=1). + Example: ``` pairs = [(791, 791), (8415, 8415), (4502, 374), (389, 389), (279, 279), (282, None), (1425, 38681), (13, 13)] - sparse_pairs = [0, 0, 1, 0, 0, 1, 1, 0] + sparse_pairs = [0, 0, 1, 0, 0, 1, 1, 0] ``` Args: pairs (List[Tuple]): Iterable of token pairs from the Sequence alignment algorithm. @@ -297,15 +327,13 @@ def convert_pairs_into_sparse(pairs: List[Tuple]) -> List[int]: (List[int]): Sparse vectors. """ return [0 if i == j else 1 for i, j in pairs] - + @staticmethod def align_pairs( - pairs1: List[Tuple], - pairs2: List[Tuple], - neutral_pair: Tuple = (None, None) + pairs1: List[Tuple], pairs2: List[Tuple], neutral_pair: Tuple = (None, None) ) -> Tuple[List[Tuple], List[Tuple]]: """SInce we compare Pairs between the Reference and the Prediction, it's possible that tokens were added or deleted - in one but not in the other. This leads to a misalignment between pairs of tokens required for calculating + in one but not in the other. This leads to a misalignment between pairs of tokens required for calculating Precision and Recall. For this reason, we add a "neutral" pair of tokens for each gap in the opposite list of pairs. This "neutral" pair @@ -316,13 +344,13 @@ def align_pairs( Before: Orig-Ref pairs: [(400, 400), (350, 350), (20, 18), (21, 40), (None, 51), (23, 23)] Orig-Pred pairs: [(400, 400), (350, 350), (None, 800), (20, 18), (21, 40), (23, 80)] - + sparse Ref: [0, 0, 1, 1, 1, 0] sparse Pred: [0, 0, 1, 1, 1, 1] After: Orig-Ref pairs: [(400, 400), (350, 350), (None, None), (20, 18), (21, 40), (None, 51) (23, 23)] - Orig-Pred pairs: [(400, 400), (350, 350), (None, 800), (20, 18), (21, 40), (None, None), (23, 80)] + Orig-Pred pairs: [(400, 400), (350, 350), (None, 800), (20, 18), (21, 40), (None, None), (23, 80)] sparse Ref: [0, 0, 0, 1, 1, 1, 0] sparse Pred: [0, 0, 1, 1, 1, 0, 1] @@ -333,7 +361,7 @@ def align_pairs( neutral_pairs (Tuple, optional): Pair to insert for alignment. Defaults to (None, None). Returns: - Tuple[List[Tuple], List[Tuple]]: Aligned list of pairs. + Tuple[List[Tuple], List[Tuple]]: Aligned list of pairs. 
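To make the pair-based metrics concrete, here is a worked sketch that reuses the aligned Orig-Ref and Orig-Pred pairs from the `align_pairs` docstring example above and recomputes precision, recall and F1-beta with the same formulas as `evaluate()`. It calls nothing from the class, so treat it as an independent illustration rather than the committed implementation.

```
import numpy as np

beta = 1.0

# Aligned pairs taken from the align_pairs() docstring example above.
ref_pairs = [(400, 400), (350, 350), (None, None), (20, 18), (21, 40), (None, 51), (23, 23)]
pred_pairs = [(400, 400), (350, 350), (None, 800), (20, 18), (21, 40), (None, None), (23, 80)]

# 1 where a pair marks a change relative to the original, 0 otherwise.
sparse_ref = [0 if i == j else 1 for i, j in ref_pairs]    # [0, 0, 0, 1, 1, 1, 0]
sparse_pred = [0 if i == j else 1 for i, j in pred_pairs]  # [0, 0, 1, 1, 1, 0, 1]

tp = int(np.sum(np.multiply(sparse_ref, sparse_pred)))                   # changed in both: 2
fp = int(np.sum(np.multiply([1 - i for i in sparse_ref], sparse_pred)))  # changed only in prediction: 2
fn = int(np.sum(np.multiply(sparse_ref, [1 - i for i in sparse_pred])))  # missed changes: 1

precision = tp / (tp + fp) if (tp + fp) != 0 else 0
recall = tp / (tp + fn) if (tp + fn) != 0 else 0
f1_beta = (
    (1 + beta**2) * precision * recall / (beta**2 * precision + recall)
    if (precision + recall) != 0
    else 0
)
print(precision, recall, round(f1_beta, 3))  # 0.5 0.666... 0.571
```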
""" # Since we insert into the list, we create copies to avoid the global modification pairs1_bis, pairs2_bis = pairs1.copy(), pairs2.copy() @@ -349,12 +377,12 @@ def align_pairs( def get_correction_true_positives( self, ref_pairs: List[Tuple], - pred_pairs: List[Tuple], + pred_pairs: List[Tuple], ) -> float: """Correction true positives corresponding to the precision of the model the predict the correct token. Note: - We consider only tokens that were modified by the model & were supposed to be modified. + We consider only tokens that were modified by the model & were supposed to be modified. It means that if the model missed a token correction, or added one, it is not considered in the correction precision calculation in case the token wasn"t supposed to be corrected. @@ -366,13 +394,13 @@ def get_correction_true_positives( float: precision of picking the right token. """ sparse_ref_pairs, sparse_pred_pairs = ( - self.convert_pairs_into_sparse(ref_pairs), - self.convert_pairs_into_sparse(pred_pairs) + self.convert_pairs_into_sparse(ref_pairs), + self.convert_pairs_into_sparse(pred_pairs), ) true_positives = np.multiply(sparse_ref_pairs, sparse_pred_pairs) correction_true_positives = [ - int(ref_pairs[idx][1] == pred_pairs[idx][1]) - for idx, tp in enumerate(true_positives) + int(ref_pairs[idx][1] == pred_pairs[idx][1]) + for idx, tp in enumerate(true_positives) if tp == 1 ] return correction_true_positives @@ -386,18 +414,22 @@ def normalize(texts: List[str]) -> List[str]: Returns: (Tuple) Processed texts """ + def process(text: str) -> str: - text = text.lower() # Lowercase - text = " ".join([token.strip() for token in text.split()]) # Normalize whitespaces - text = text.replace("œ", "oe") # Oeuf, Boeuf, ... - text = text.replace("ï", "i") # Maïs, ... - text = text.replace("â", "a") - text = text.replace("flavour", "flavor") # US/UK + text = text.lower() # Lowercase + text = " ".join( + [token.strip() for token in text.split()] + ) # Normalize whitespaces + text = text.replace("œ", "oe") # Oeuf, Boeuf, ... + text = text.replace("ï", "i") # Maïs, ... + text = text.replace("â", "a") + text = text.replace("flavour", "flavor") # US/UK text = text.replace("colour", "color") text = text.replace("pasteurized", "pasteurised") - text = unidecode(text) # Remove accents - text = text.replace("\n", "") # Counted as an error by evaluator + text = unidecode(text) # Remove accents + text = text.replace("\n", "") # Counted as an error by evaluator return text + return [process(text) for text in texts] @@ -405,25 +437,26 @@ def process(text: str) -> str: # DEBUG ORGINALS = [ - "cacao maigre en Sucre poudre 20% - émulsifiant : léci - thines de tournesol - carbo - nate de magnésium", - "Ananas, Ananassaft, Säuerungs - mittel: Citronensäure", - "_Cacahuetes_ con cáscara tostado. _Trazas de frutos de cáscara_.", - "The cas is on the firdge" -] + "cacao maigre en Sucre poudre 20% - émulsifiant : léci - thines de tournesol - carbo - nate de magnésium", + "Ananas, Ananassaft, Säuerungs - mittel: Citronensäure", + "_Cacahuetes_ con cáscara tostado. _Trazas de frutos de cáscara_.", + "The cas is on the firdge", + ] REFERENCES = [ - "cacao maigre en Sucre poudre 20% - émulsifiant : lécithines de tournesol - carbonate de magnésium", - "Ananas, Ananassaft, Säuerungsmittel: Citronensäure", - "_Cacahuetes_ con cáscara tostado. 
Trazas de frutos de cáscara.", - "The cat is in the fridge" -] + "cacao maigre en Sucre poudre 20% - émulsifiant : lécithines de tournesol - carbonate de magnésium", + "Ananas, Ananassaft, Säuerungsmittel: Citronensäure", + "_Cacahuetes_ con cáscara tostado. Trazas de frutos de cáscara.", + "The cat is in the fridge", + ] PREDICTIONS = [ - "cacao maigre en Sucre pdre 20% - émulsifiant : lécithines de tournesol - carbona de magnésium", - "Ananas, Säuerungsmittel: Citronensäure", - "Cacahuetes con cáscara tostado. _Trazas de frutos de cáscara_.", - "The big cat is in the fridge" - -] + "cacao maigre en Sucre pdre 20% - émulsifiant : lécithines de tournesol - carbona de magnésium", + "Ananas, Säuerungsmittel: Citronensäure", + "Cacahuetes con cáscara tostado. _Trazas de frutos de cáscara_.", + "The big cat is in the fridge", + ] spellcheck_evaluator = SpellcheckEvaluator(originals=ORGINALS) - results = spellcheck_evaluator.evaluate(predictions=PREDICTIONS, references=REFERENCES) + results = spellcheck_evaluator.evaluate( + predictions=PREDICTIONS, references=REFERENCES + ) print(results) diff --git a/spellcheck/src/spellcheck/model.py b/spellcheck/src/spellcheck/model.py index 3c67a3a1..e59d3e43 100644 --- a/spellcheck/src/spellcheck/model.py +++ b/spellcheck/src/spellcheck/model.py @@ -1,4 +1,5 @@ """Spellcheck models""" + import os from abc import ABC, abstractmethod from typing import Literal @@ -32,6 +33,7 @@ class OpenAIChatCompletion(BaseModel): temperature (float, optional): _description_. Defaults to 0. max_tokens (int, optional): _description_. Defaults to 512. """ + def __init__( self, prompt_template: str, @@ -52,20 +54,22 @@ def __init__( self.max_tokens = max_tokens def generate(self, text: str) -> str: - messages = self.messages + [{"role": "user", "content": self.prompt_template.format(text)}] + messages = self.messages + [ + {"role": "user", "content": self.prompt_template.format(text)} + ] response = self.client.chat.completions.create( model=self.model_name, messages=messages, temperature=self.temperature, - max_tokens=self.max_tokens + max_tokens=self.max_tokens, ) output_text = response.choices[0].message.content return output_text.strip() class AnthropicChatCompletion(BaseModel): - """LLMs from Anthropic - """ + """LLMs from Anthropic""" + def __init__( self, prompt_template: str, @@ -73,7 +77,7 @@ def __init__( model_name: Literal[ "claude-3-haiku-20240307", "claude-3-sonnet-20240229", - "claude-3-opus-20240229" + "claude-3-opus-20240229", ], temperature: float = 0, max_tokens: int = 512, @@ -84,17 +88,17 @@ def __init__( self.model_name = model_name self.temperature = temperature self.max_tokens = max_tokens - + def generate(self, text: str) -> str: message = self.client.messages.create( model=self.model_name, system=self.system_prompt, messages=[{"role": "user", "content": self.prompt_template.format(text)}], temperature=self.temperature, - max_tokens=self.max_tokens + max_tokens=self.max_tokens, ) return message.content[0].text - + class RulesBasedModel(BaseModel): """Rules-based methods.""" @@ -102,7 +106,7 @@ class RulesBasedModel(BaseModel): @staticmethod def generate(text: str) -> str: return text.replace("léci - thine", "lécithine") - + class GeminiModel(BaseModel): """Google Gemini.""" @@ -114,21 +118,21 @@ def __init__( model_name: Literal[ "gemini-1.0-pro-002", "gemini-1.5-flash-preview-0514", - "gemini-1.5-pro-preview-0409" - ], + "gemini-1.5-pro-preview-0409", + ], temperature: float = 0, - max_tokens: int= 512, + max_tokens: int = 512, project_id: str = 
"robotoff", - location: str = "us-central1" + location: str = "us-central1", ) -> None: self.prompt_template = prompt_template - # Init model + # Init model vertexai.init(project=project_id, location=location) generation_config = { "temperature": temperature, "max_output_tokens": max_tokens, - "response_mime_type": "text/plain" + "response_mime_type": "text/plain", } self.model = GenerativeModel( model_name=model_name, @@ -153,16 +157,16 @@ class LLMInferenceEndpoint(BaseModel): """Open-Source LLM deployed on Hugging Face Inference Endpoints.""" def __init__( - self, - prompt_template: str, - system_prompt: str, - temperature: int = 0, - max_tokens: int = 512 - ) -> None: + self, + prompt_template: str, + system_prompt: str, + temperature: int = 0, + max_tokens: int = 512, + ) -> None: self.client = OpenAI( base_url=os.getenv("HF_INFERENCE_ENDPOINT_URL") + "/v1/", - api_key=os.getenv("HF_TOKEN") - ) # With TGI, we can use the existing OpenAI API to run our LLM + api_key=os.getenv("HF_TOKEN"), + ) # With TGI, we can use the existing OpenAI API to run our LLM self.prompt_template = prompt_template self.system_prompt = system_prompt self.temperature = temperature @@ -172,11 +176,14 @@ def generate(self, text: str) -> str: message = self.client.chat.completions.create( model="tgi", messages=[ - {"role": "user", "content": self.system_prompt + "\n\n" + self.prompt_template.format(text)}, + { + "role": "user", + "content": self.system_prompt + + "\n\n" + + self.prompt_template.format(text), + }, ], temperature=self.temperature, max_tokens=self.max_tokens, ) return message.choices[0].message.content - - diff --git a/spellcheck/src/spellcheck/processing.py b/spellcheck/src/spellcheck/processing.py index a1fe59fb..e54d7ad5 100644 --- a/spellcheck/src/spellcheck/processing.py +++ b/spellcheck/src/spellcheck/processing.py @@ -11,11 +11,7 @@ class DataProcessor: """Class methods to process datasets for the Spellcheck.""" @classmethod - def align_oe( - cls, - references: Iterable[str], - texts: Iterable[str] - ) -> Iterable[str]: + def align_oe(cls, references: Iterable[str], texts: Iterable[str]) -> Iterable[str]: """Align "oe" - "œ character between the reference text and the target. Examples: @@ -27,11 +23,8 @@ def align_oe( references (Iterable[str]): Text references. text (Iterable[str]): Texts to align with references. """ - return [ - cls._align_oe(ref, text) - for ref, text in zip(references, texts) - ] - + return [cls._align_oe(ref, text) for ref, text in zip(references, texts)] + def _align_oe(reference: str, text: str) -> str: """ Args: @@ -41,8 +34,8 @@ def _align_oe(reference: str, text: str) -> str: Returns: str: Text with identical oe. 
""" - pattern = re.compile(r"oe|œ") # OR - + pattern = re.compile(r"oe|œ") # OR + ref_matches = pattern.finditer(reference) text_matches = pattern.finditer(text) @@ -65,11 +58,11 @@ def replace_match(match: re.Match) -> str: ref_match = matches_dict.get(match.span()) # Match.group() return the match as a string # Possibility a oe was not found in the reference, leading to dict.get()= None - if ref_match and ref_match != match.group(): - return ref_match # Reference + if ref_match and ref_match != match.group(): + return ref_match # Reference else: return match.group() - + modified_text = pattern.sub(replace_match, text) return modified_text @@ -136,7 +129,10 @@ def replace_match(match: re.Match) -> str: # Compose the new match with the space between group 1 and group 3 text_whitespace = match.group(2) ref_whitespace = matches_dict.get(match.span()) - if ref_whitespace in ["", " "] and text_whitespace in ["", " "]: # Could be empty string "" or whitespace " " + if ref_whitespace in ["", " "] and text_whitespace in [ + "", + " ", + ]: # Could be empty string "" or whitespace " " return match.group(1) + ref_whitespace + match.group(3) else: LOGGER.debug( diff --git a/spellcheck/src/spellcheck/prompt.py b/spellcheck/src/spellcheck/prompt.py index 0e1e50e6..178dd4a5 100644 --- a/spellcheck/src/spellcheck/prompt.py +++ b/spellcheck/src/spellcheck/prompt.py @@ -1,6 +1,8 @@ """Prompts for LLMs""" + from dataclasses import dataclass + @dataclass class SystemPrompt: """Class containing system prompt used in Chat Completion""" @@ -46,9 +48,10 @@ class SystemPrompt: Cacao*, azúcar de coco* (30%), manteca de cacao, frambuesa deshidratada (1 %), açai deshidratado* (0,5 % ) """ + @dataclass class Prompt: """Class containing LLM prompts""" spellcheck_prompt_template = """Remember to let the text as unchanged as possible. Focus on the guidelines.\n\n###List of ingredients:\n{}\n\n###Corrected list of ingredients:\n""" - claude_spellcheck_prompt_template = """Just print the corrected list of ingredients and nothing else!\n###List of ingredients:\n{}\n\n###Corrected list of ingredients:\n""" \ No newline at end of file + claude_spellcheck_prompt_template = """Just print the corrected list of ingredients and nothing else!\n###List of ingredients:\n{}\n\n###Corrected list of ingredients:\n""" diff --git a/spellcheck/src/spellcheck/spellcheck.py b/spellcheck/src/spellcheck/spellcheck.py index 9478fe62..5255c2ea 100644 --- a/spellcheck/src/spellcheck/spellcheck.py +++ b/spellcheck/src/spellcheck/spellcheck.py @@ -1,4 +1,5 @@ """Spellcheck module.""" + import logging from spellcheck.model import BaseModel @@ -6,16 +7,17 @@ class Spellcheck: """Spellcheck module to correct typos and errors in lists of ingredients. - + Args: model (BaseModel): model used for correcting list of ingredients. """ + def __init__(self, model: BaseModel) -> None: self.model = model def correct(self, text: str) -> str: """Correct list of ingredients - + Args: list_of_ingredients (str): List of ingredients to correct. 
diff --git a/spellcheck/src/spellcheck/training/configs.py b/spellcheck/src/spellcheck/training/configs.py index e343d912..24881ffe 100644 --- a/spellcheck/src/spellcheck/training/configs.py +++ b/spellcheck/src/spellcheck/training/configs.py @@ -8,7 +8,9 @@ class DataProcessingConfig: class SFTDataProcessingConfig(DataProcessingConfig): - instruction_template: str = "###Correct the list of ingredients:\n{}\n\n###Correcton:\n" #TODO: add jinja instruction + instruction_template: str = ( + "###Correct the list of ingredients:\n{}\n\n###Correcton:\n" # TODO: add jinja instruction + ) @dataclass @@ -38,6 +40,7 @@ class SavingConfig: merge_weights: Whether to merge the adapter and model weights with. max_shard_size: Maximum shard size. """ + merge_weights: bool = field(default=False) max_shard_size: str = field(default="2GB") diff --git a/spellcheck/src/spellcheck/training/trainer.py b/spellcheck/src/spellcheck/training/trainer.py index 34159342..6bc7004d 100644 --- a/spellcheck/src/spellcheck/training/trainer.py +++ b/spellcheck/src/spellcheck/training/trainer.py @@ -1,17 +1,7 @@ import os from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import ( - Optional, - Mapping, - List, - Union, - Any, - Iterable, - Tuple, - Type, - Dict -) +from typing import Optional, Mapping, List, Union, Any, Iterable, Tuple, Type, Dict from functools import partial from tqdm import tqdm @@ -46,29 +36,29 @@ @dataclass class Datasets: - """Dataclass to store training and evaluation datasets, raw or processed. - """ + """Dataclass to store training and evaluation datasets, raw or processed.""" + training_dataset: Dataset evaluation_dataset: Optional[Dataset] = None class DataProcessor(ABC, BaseModel): """Processing class to transform datasets for the model training and evaluation. - + The class is designed to inherite the data processing job adapted to the training algorithm, such as SFT, DPO, Instruction-tuning, and so on... """ data_processsing_config: DataProcessingConfig - + @abstractmethod def _process_fn( - self, + self, element: Mapping[str, Union[Any, List]], text_feature: str, label_feature: str, ) -> Mapping[str, Union[Any, List]]: """Processing function used within the Dataset.map() method from the 'datasets' library. - + The control the behavior of this function, one should create a new class that inherates from DataProcessor and build its own _process_fn() method. The latest is then used in the process() method from the base model. @@ -81,12 +71,12 @@ def _process_fn( Mapping[str, Union[Any, List]]: Processed elements. """ raise NotImplementedError - + def process_datasets( - self, + self, datasets: Datasets, training_data_features: TrainingDataFeatures, - evaluation_data_features: Optional[EvaluationDataFeatures] = None + evaluation_data_features: Optional[EvaluationDataFeatures] = None, ) -> Datasets: """Performs datasets processing. @@ -106,7 +96,9 @@ def process_datasets( ) # Evaluation dataset if not datasets.evaluation_dataset and evaluation_data_features: - LOGGER.warning(f"Evaluation processing features provided but no evaluation dataset was provided. Datasets dataclass currently provided: {datasets}") + LOGGER.warning( + f"Evaluation processing features provided but no evaluation dataset was provided. 
Datasets dataclass currently provided: {datasets}" + ) elif datasets.evaluation_dataset: processed_evaluation_dataset = self._map_dataset( dataset=datasets.evaluation_dataset, @@ -115,18 +107,13 @@ def process_datasets( ) return Datasets( training_dataset=processed_training_dataset, - evaluation_dataset=processed_evaluation_dataset + evaluation_dataset=processed_evaluation_dataset, ) # If only train dataset - return Datasets( - training_dataset=processed_training_dataset - ) - + return Datasets(training_dataset=processed_training_dataset) + def _map_dataset( - self, - dataset: Dataset, - text_feature: str, - label_feature: str + self, dataset: Dataset, text_feature: str, label_feature: str ) -> Dataset: """Method using map() method from the datasets library with additional arguments. @@ -147,39 +134,38 @@ def _map_dataset( batched=self.data_processsing_config.batched, remove_columns=dataset.column_names, ) - + @abstractmethod def process_texts(self, texts: Iterable[str]) -> Iterable[str]: """Text processing abstract method used during inference. Args: - texts (Iterable[str]): Batch of texts to process. + texts (Iterable[str]): Batch of texts to process. Returns: Iterable[str]: Processed texts. """ raise NotImplementedError - + class SFTDataProcessor(DataProcessor): - """Data processing engine for Supervised Fine Tuning training. - """ + """Data processing engine for Supervised Fine Tuning training.""" data_processsing_config: SFTDataProcessingConfig def _process_fn( self, - element: Mapping[str, Union[Any, List]], - text_feature: str, - label_feature: str + element: Mapping[str, Union[Any, List]], + text_feature: str, + label_feature: str, ) -> Mapping[str, Union[Any, List]]: """Prepare data for Instruction fine-tuning using the SFT Trainer. - + The latest expects the feature column 'text'. The text input and label are concatenated into one instruction-prompt using the 'instruction_template'. Args: - element (Dict[str, Union[Any, List]]): Element during dataset mapping. + element (Dict[str, Union[Any, List]]): Element during dataset mapping. text_feature (str): Text column name in the dataset. label_feature (str): Label column name in the dataset @@ -190,9 +176,9 @@ def _process_fn( return {"text": instruction + element[label_feature]} def _prepare_instruction(self, text: str) -> str: - """Prepare instruction based on the instruction-template. + """Prepare instruction based on the instruction-template. This function is primordial for the training step, but also during inference. - + Args: text (str): Text to process. @@ -205,23 +191,23 @@ def process_texts(self, texts: Iterable[str]) -> Iterable[str]: """Text processing method for SFTDataProcessor used during inference. Args: - texts (Iterable[str]): Batch of texts to process. + texts (Iterable[str]): Batch of texts to process. Returns: Iterable[str]: Processed texts. """ return [self._prepare_instruction(text) for text in texts] - + class SavingProcessor(ABC, BaseModel): - """Saving processor abstract class after training. - """ + """Saving processor abstract class after training.""" + model_config = ConfigDict(arbitrary_types_allowed=True) @abstractmethod def save_trainer(self, trainer: Trainer): raise NotImplementedError - + class LoRASavingProcessor(SavingProcessor): """Saving processor for QLoRA training. Save adapters with our without the model. @@ -230,13 +216,14 @@ class LoRASavingProcessor(SavingProcessor): output_dir (str): Directory to save the model. saving_config (SavingConfig): Saving configuration. 
""" + output_dir: str saving_config: SavingConfig - + def save_trainer(self, trainer: Trainer) -> None: """Use trainer instance after training to save tokenizer and fine-tuned model. - - Check these links to know more about saving LoRA adapters after training: + + Check these links to know more about saving LoRA adapters after training: * https://www.philschmid.de/fine-tune-llms-in-2024-with-trl * https://github.com/philschmid/llm-sagemaker-sample/blob/main/scripts/run_qlora.py @@ -255,7 +242,7 @@ def save_trainer(self, trainer: Trainer) -> None: # Load PEFT model in fp16. It uses the saved LoRA adapters and load the pretrained model for the HF hub. LOGGER.info("Load PEFT model.") - torch_dtype = torch.bfloat16 #TODO: add bf16 to arguments + torch_dtype = torch.bfloat16 # TODO: add bf16 to arguments model = AutoPeftModelForCausalLM.from_pretrained( self.output_dir, low_cpu_mem_usage=True, @@ -265,23 +252,25 @@ def save_trainer(self, trainer: Trainer) -> None: model = model.merge_and_unload() # Save merged model model.save_pretrained( - self.output_dir, - safe_serialization=True, - max_shard_size=self.saving_config.max_shard_size + self.output_dir, + safe_serialization=True, + max_shard_size=self.saving_config.max_shard_size, ) LOGGER.info("Model merged and saved succesfully.") del model - # Remove adapters from the directory. The reason being the model.from_pretrained() method will load adapters instead of the merged model. + # Remove adapters from the directory. The reason being the model.from_pretrained() method will load adapters instead of the merged model. try: os.remove(os.path.join(self.output_dir, "adapter_config.json")) os.remove(os.path.join(self.output_dir, "adapter_model.safetensors")) except Exception as e: - LOGGER.warning(f"Something went wrong with trying to remove adapters files for the training directory. Error: {e}") + LOGGER.warning( + f"Something went wrong with trying to remove adapters files for the training directory. Error: {e}" + ) else: # Save adapters only - trainer.model.save_pretrained(self.output_dir) + trainer.model.save_pretrained(self.output_dir) class InferenceProcessor(ABC, BaseModel): @@ -294,6 +283,7 @@ class InferenceProcessor(ABC, BaseModel): data_processor (Optional[DataProcessor]): Data processor instance. device (str): Device to use for inference. """ + model_config = ConfigDict(arbitrary_types_allowed=True) tokenizer: PreTrainedTokenizerBase @@ -308,13 +298,13 @@ def inference(self): @classmethod def load_pretrained( - cls, + cls, model_dir: str, model_config: ModelConfig, data_processor: DataProcessor, - inference_config: InferenceConfig + inference_config: InferenceConfig, ) -> None: - """Class method to load model and tokenizer for inference. + """Class method to load model and tokenizer for inference. Args: model_dir (str): Model directory. @@ -327,7 +317,7 @@ def load_pretrained( tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token_id = tokenizer.eos_token_id - torch_dtype = torch.bfloat16 #: add bfloat16 to arguments + torch_dtype = torch.bfloat16 #: add bfloat16 to arguments model = AutoModelForCausalLM.from_pretrained( model_dir, device_map=model_config.device_map, @@ -339,12 +329,12 @@ def load_pretrained( tokenizer=tokenizer, model=model, data_processor=data_processor, - inference_config=inference_config + inference_config=inference_config, ) - + def _batch_process(self, lst: Iterable, batch_size: int = 1): """Batch inputs. - + Args: lst (Iterable): List of inputs. 
""" @@ -353,35 +343,40 @@ def _batch_process(self, lst: Iterable, batch_size: int = 1): class TextGenerationInference(InferenceProcessor): - """Inference for Text Generation models. - """ + """Inference for Text Generation models.""" def inference( - self, - texts: Iterable[str], + self, + texts: Iterable[str], ) -> Iterable[str]: """Perform text generation. - + Args: texts (Iterable[str]): Batch of texts to process.""" predictions = [] processed_texts = self.data_processor.process_texts(texts) if self.inference_config.batch_size > 1: - processed_texts = self._batch_process(processed_texts, batch_size=self.inference_config.batch_size) - + processed_texts = self._batch_process( + processed_texts, batch_size=self.inference_config.batch_size + ) + for text_batch in tqdm( - processed_texts, - total=len(texts), - desc="Prediction" if self.inference_config.batch_size == 1 else f"Prediction in batch: batch_size = {self.inference_config.batch_size == 1}" + processed_texts, + total=len(texts), + desc=( + "Prediction" + if self.inference_config.batch_size == 1 + else f"Prediction in batch: batch_size = {self.inference_config.batch_size == 1}" + ), ): encodings = self.tokenizer( - text_batch, - add_special_tokens=True, + text_batch, + add_special_tokens=True, return_tensors="pt", - padding="longest", # In batch, required padding strategy between + padding="longest", # In batch, required padding strategy between ) - encodings = {k: v.to(self.device) for k,v in encodings.items()} + encodings = {k: v.to(self.device) for k, v in encodings.items()} pred_encodings = self.model.generate( **encodings, do_sample=False, @@ -395,14 +390,13 @@ def inference( predictions.extend(prediction_batch) return predictions - def _post_process( - self, - encodings: Mapping, - text_batch: Iterable[str] - ) -> List[str]: + def _post_process(self, encodings: Mapping, text_batch: Iterable[str]) -> List[str]: """""" - predictions = self.tokenizer.batch_decode(encodings, skip_special_tokens=True) - return [prediction[len(text):].strip() for prediction, text in zip(predictions, text_batch)] + predictions = self.tokenizer.batch_decode(encodings, skip_special_tokens=True) + return [ + prediction[len(text) :].strip() + for prediction, text in zip(predictions, text_batch) + ] class EvaluationProcessor(BaseModel): @@ -417,75 +411,81 @@ class EvaluationProcessor(BaseModel): evaluator: Optional[SpellcheckEvaluator] = Field(default=None, init=False) def model_post_init(self, __context: Any): - """Prepare evaluator during post_init: + """Prepare evaluator during post_init: https://docs.pydantic.dev/latest/api/base_model/#pydantic.BaseModel.model_post_init """ orginals, _ = self._prepare_data() self.evaluator = self.evaluator_type(originals=orginals) - + def evaluate(self, save_predictions_path: Optional[str]) -> Dict[str, float]: - """ - """ + """ """ # Load texts originals, references = self._prepare_data() # Predictions predictions = self.inference_processor.inference(texts=originals) # Evaluation - metrics = self.evaluator.evaluate(predictions=predictions, references=references) + metrics = self.evaluator.evaluate( + predictions=predictions, references=references + ) LOGGER.info(f"Evaluation metrics: {metrics}") # Save predictions prediction_dataset = self.evaluation_dataset.add_column( - name="prediction", - column=predictions + name="prediction", column=predictions ) if save_predictions_path: LOGGER.info(f"Predictions are saved in: {save_predictions_path}") prediction_dataset.save_to_disk(save_predictions_path) return metrics - + def 
_prepare_data(self) -> Tuple[Iterable[str], Iterable[str]]: """""" originals = self.evaluation_dataset[self.evaluation_features.eval_text_feature] - references = self.evaluation_dataset[self.evaluation_features.eval_label_feature] + references = self.evaluation_dataset[ + self.evaluation_features.eval_label_feature + ] return originals, references class ExperimentLogger(ABC, BaseModel): """Class to log experiment information on experiment tracker.""" + model_config = ConfigDict(arbitrary_types_allowed=True) @abstractmethod - def log(self): """Log method.""" raise NotImplementedError - + class CometExperimentLogger(ExperimentLogger): """CometML Experiment Tracker class""" - + experiment: comet_ml.Experiment workspace: str = os.getenv("COMET_WORKSPACE_NAME") project_name: str = os.getenv("COMET_PROJECT_NAME") api_key: str = os.getenv("COMET_API_KEY") - + @classmethod def load_experiment(cls): global_experiment = comet_ml.get_global_experiment() - experiment = global_experiment if global_experiment else comet_ml.Experiment( - api_key=cls.api_key, - project_name=cls.project_name, - workspace=cls.workspace, + experiment = ( + global_experiment + if global_experiment + else comet_ml.Experiment( + api_key=cls.api_key, + project_name=cls.project_name, + workspace=cls.workspace, + ) ) return cls(experiment=experiment) def log( - self, - metrics: Optional[Mapping] = None, - parameters: Optional[Mapping] = None, + self, + metrics: Optional[Mapping] = None, + parameters: Optional[Mapping] = None, model_uri: Optional[str] = None, model_name: str = "model", - tags: Optional[List[str]] = None + tags: Optional[List[str]] = None, ) -> None: """Log data in experiment tracker. @@ -512,9 +512,7 @@ def _log_parameters(self, parameters: Dict[str, Any]) -> None: self.experiment.log_parameters(parameters) def _log_model(self, model_uri: str, model_name: str) -> None: - self.experiment.log_remote_model( - model_name, model_uri, sync_mode=False - ) + self.experiment.log_remote_model(model_name, model_uri, sync_mode=False) def _log_tags(self, tags: List[str]) -> None: self.experiment.add_tags(tags) diff --git a/spellcheck/src/spellcheck/training/utils.py b/spellcheck/src/spellcheck/training/utils.py index f50b6251..90759acf 100644 --- a/spellcheck/src/spellcheck/training/utils.py +++ b/spellcheck/src/spellcheck/training/utils.py @@ -12,7 +12,7 @@ class CustomCometCallback(CometCallback): def setup(self, args, state, model): """ - Custom CometCallback to use an existing experiment during training using Trainer + Custom CometCallback to use an existing experiment during training using Trainer Modification: add get_gloabl_experiment() before creating another experiment instance in CometML """ self._initialized = True @@ -21,19 +21,33 @@ def setup(self, args, state, model): self._log_assets = True if state.is_world_process_zero: comet_mode = os.getenv("COMET_MODE", "ONLINE").upper() - experiment_kwargs = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")} + experiment_kwargs = { + "project_name": os.getenv("COMET_PROJECT_NAME", "huggingface") + } if comet_mode == "ONLINE": ### MODIFICATION - experiment = comet_ml.get_global_experiment() if comet_ml.get_global_experiment() else comet_ml.Experiment(**experiment_kwargs) + experiment = ( + comet_ml.get_global_experiment() + if comet_ml.get_global_experiment() + else comet_ml.Experiment(**experiment_kwargs) + ) experiment.log_other("Created from", "transformers") logger.info("Automatic Comet.ml online logging enabled") elif comet_mode == "OFFLINE": - 
experiment_kwargs["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./") + experiment_kwargs["offline_directory"] = os.getenv( + "COMET_OFFLINE_DIRECTORY", "./" + ) experiment = comet_ml.OfflineExperiment(**experiment_kwargs) experiment.log_other("Created from", "transformers") - logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished") + logger.info( + "Automatic Comet.ml offline logging enabled; use `comet upload` when finished" + ) if experiment is not None: experiment._set_model_graph(model, framework="transformers") - experiment._log_parameters(args, prefix="args/", framework="transformers") + experiment._log_parameters( + args, prefix="args/", framework="transformers" + ) if hasattr(model, "config"): - experiment._log_parameters(model.config, prefix="config/", framework="transformers") \ No newline at end of file + experiment._log_parameters( + model.config, prefix="config/", framework="transformers" + ) diff --git a/spellcheck/src/spellcheck/utils.py b/spellcheck/src/spellcheck/utils.py index 5ca7af2e..f0a4a789 100644 --- a/spellcheck/src/spellcheck/utils.py +++ b/spellcheck/src/spellcheck/utils.py @@ -35,10 +35,10 @@ def get_logger(level: Optional[str] = None) -> logging.Logger: def show_diff( - original_text: str, + original_text: str, corrected_text: str, color: Literal["yellow", "orange", "red"] = "yellow", - deleted_element: str = ArgillaConfig.deleted_element + deleted_element: str = ArgillaConfig.deleted_element, ) -> str: """Unify operations between two compared strings seqm is a difflib.SequenceMatcher instance whose a & b are strings @@ -58,20 +58,32 @@ def show_diff( html_color = ArgillaConfig.html_colors[color] # Check if the process was not done if "" not in corrected_text: - seqm = difflib.SequenceMatcher(None, original_text, corrected_text) - output= [] + seqm = difflib.SequenceMatcher(None, original_text, corrected_text) + output = [] for opcode, a0, a1, b0, b1 in seqm.get_opcodes(): - if opcode == 'equal': + if opcode == "equal": output.append(seqm.a[a0:a1]) - elif opcode == 'insert': - output.append(f"" + seqm.b[b0:b1] + "") - elif opcode == 'delete': - output.append(f"" + deleted_element + "") - elif opcode == 'replace': - output.append(f"" + seqm.b[b0:b1] + "") + elif opcode == "insert": + output.append( + f"" + + seqm.b[b0:b1] + + "" + ) + elif opcode == "delete": + output.append( + f"" + + deleted_element + + "" + ) + elif opcode == "replace": + output.append( + f"" + + seqm.b[b0:b1] + + "" + ) else: raise RuntimeError("unexpected opcode") - return ''.join(output) + return "".join(output) else: return corrected_text @@ -94,11 +106,15 @@ def load_jsonl(path: Path) -> Iterable[Mapping]: def timer(fn): """Decorator to track function duration.""" + def wrapper(*args, **kwargs): logger = get_logger() timestamp = time.time() logger.info(f"Start {fn.__name__}.") output = fn(*args, **kwargs) - logger.info(f"The function {fn.__name__} took {round(time.time() - timestamp)} to finish.") + logger.info( + f"The function {fn.__name__} took {round(time.time() - timestamp)} to finish." 
+ ) return output - return wrapper \ No newline at end of file + + return wrapper diff --git a/spellcheck/tests/test_argilla.py b/spellcheck/tests/test_argilla.py index 076fc0a8..c1cffd08 100644 --- a/spellcheck/tests/test_argilla.py +++ b/spellcheck/tests/test_argilla.py @@ -7,7 +7,7 @@ BenchmarkArgilla, BenchmarkEvaluationArgilla, IngredientsCompleteEvaluationArgilla, - TrainingDataArgilla + TrainingDataArgilla, ) from spellcheck.argilla.extraction import SpellcheckExtraction @@ -16,15 +16,21 @@ REFERENCES = ["text1", "text", "text3"] PREDICTIONS = ["text1", "text2", "text"] METADATA = [ - {"lang": "fr", "code": 123456789}, - {"lang": "en", "code": 123456789}, - {"lang": "es", "code": 123456789} + {"lang": "fr", "code": 123456789}, + {"lang": "en", "code": 123456789}, + {"lang": "es", "code": 123456789}, ] ARGILLA_EXTRACTION_EXAMPLE_1 = { "url": "https://world.openfoodfacts.org/product/5942262001416", "original": "water:snow", - "reference": [{"user_id": "dfb71753-1187-45e1-8006-629bef2b49e0", "value": "water:snow", "status": "discarded"}], + "reference": [ + { + "user_id": "dfb71753-1187-45e1-8006-629bef2b49e0", + "value": "water:snow", + "status": "discarded", + } + ], "reference-suggestion": "water:snow", "reference-suggestion-metadata": {"type": None, "score": None, "agent": None}, "is_truncated": [], @@ -50,14 +56,20 @@ "original": "Ananas, Ananassaft, Säuerungs - mittel: Citronensäure", "reference": [ { - "user_id": "dfb71753-1187-45e1-8006-629bef2b49e0", - "value": "Ananas, Ananassaft, Säuerungsmittel: Citronensäure", - "status": "submitted" + "user_id": "dfb71753-1187-45e1-8006-629bef2b49e0", + "value": "Ananas, Ananassaft, Säuerungsmittel: Citronensäure", + "status": "submitted", } ], "reference-suggestion": "Ananas, Ananassaft, Säuerungsmittel: Citronensäure", "reference-suggestion-metadata": {"type": None, "score": None, "agent": None}, - "is_truncated": [{"user_id": "dfb71753-1187-45e1-8006-629bef2b49e0", "value": "NO", "status": "submitted"}], + "is_truncated": [ + { + "user_id": "dfb71753-1187-45e1-8006-629bef2b49e0", + "value": "NO", + "status": "submitted", + } + ], "is_truncated-suggestion": None, "is_truncated-suggestion-metadata": {"type": None, "score": None, "agent": None}, "external_id": None, @@ -71,7 +83,7 @@ def modules() -> Iterable[ArgillaModule]: BenchmarkArgilla(ORIGINALS, REFERENCES, METADATA), BenchmarkEvaluationArgilla(ORIGINALS, REFERENCES, PREDICTIONS, METADATA), IngredientsCompleteEvaluationArgilla(ORIGINALS, PREDICTIONS, METADATA), - TrainingDataArgilla(ORIGINALS, REFERENCES, METADATA) + TrainingDataArgilla(ORIGINALS, REFERENCES, METADATA), ] @@ -79,7 +91,7 @@ def modules() -> Iterable[ArgillaModule]: def evaluation_modules() -> Iterable[ArgillaModule]: return [ BenchmarkEvaluationArgilla(ORIGINALS, REFERENCES, PREDICTIONS, METADATA), - IngredientsCompleteEvaluationArgilla(ORIGINALS, PREDICTIONS, METADATA) + IngredientsCompleteEvaluationArgilla(ORIGINALS, PREDICTIONS, METADATA), ] @@ -104,26 +116,16 @@ def test_evaluation_highlights(evaluation_modules: Iterable[ArgillaModule]): @pytest.mark.parametrize( "inputs, expected", [ - ( - (ARGILLA_EXTRACTION_EXAMPLE_1, ["submitted"]), - False - ), - ( - (ARGILLA_EXTRACTION_EXAMPLE_2, ["submitted", "pending"]), - True - ), - ( - (ARGILLA_EXTRACTION_EXAMPLE_3, ["submitted"]), - True - ) - ] + ((ARGILLA_EXTRACTION_EXAMPLE_1, ["submitted"]), False), + ((ARGILLA_EXTRACTION_EXAMPLE_2, ["submitted", "pending"]), True), + ((ARGILLA_EXTRACTION_EXAMPLE_3, ["submitted"]), True), + ], ) def 
test_argilla_spellcheck_extraction_filter(inputs, expected): """Test Argilla dataset extraction filtering.""" element, status = inputs is_kept = SpellcheckExtraction( - dataset_name="test_dataset", - extracted_status=status + dataset_name="test_dataset", extracted_status=status )._filter_fn(element) assert is_kept == expected @@ -139,7 +141,7 @@ def test_argilla_spellcheck_extraction_filter(inputs, expected): "lang": "de", "data_origin": "labeled_data", "is_truncated": 0, - } + }, ), ( ARGILLA_EXTRACTION_EXAMPLE_2, @@ -149,15 +151,14 @@ def test_argilla_spellcheck_extraction_filter(inputs, expected): "lang": "ro", "data_origin": "50-percent-unknown", "is_truncated": 0, - } + }, ), - ] + ], ) def test_argilla_spellcheck_extraction_map(inputs, expected): """Test Argilla dataset extraction mapping.""" element = inputs extracted = SpellcheckExtraction( - dataset_name="test_dataset", - extracted_status=["submitted"] + dataset_name="test_dataset", extracted_status=["submitted"] )._map_fn(element) assert extracted == expected diff --git a/spellcheck/tests/test_evaluate.py b/spellcheck/tests/test_evaluate.py index 034ed464..be3a75fa 100644 --- a/spellcheck/tests/test_evaluate.py +++ b/spellcheck/tests/test_evaluate.py @@ -7,27 +7,26 @@ "cacao maigre en Sucre poudre 20% - émulsifiant : léci - thines de tournesol - carbo - nate de magnésium", "Ananas, Ananassaft, Säuerungs - mittel: Citronensäure", "_Cacahuetes_ con cáscara tostado. _Trazas de frutos de cáscara_.", - "The cas is on the firdge" + "The cas is on the firdge", ] REFERENCES = [ "cacao maigre en Sucre poudre 20% - émulsifiant : lécithines de tournesol - carbonate de magnésium", "Ananas, Ananassaft, Säuerungsmittel: Citronensäure", "_Cacahuetes_ con cáscara tostado. Trazas de frutos de cáscara.", - "The cat is in the fridge" + "The cat is in the fridge", ] PREDICTIONS = [ "cacao maigre en Sucre pdre 20% - émulsifiant : lécithines de tournesol - carbona de magnésium", "Ananas, Säuerungsmittel: Citronensäure", "Cacahuetes con cáscara tostado. 
_Trazas de frutos de cáscara_.", - "The big cat is in the fridge" - + "The big cat is in the fridge", ] # Init evaluator = SpellcheckEvaluator(originals=ORGINALS) -def test_evaluate(predictions = PREDICTIONS, references = REFERENCES): +def test_evaluate(predictions=PREDICTIONS, references=REFERENCES): """Test the overall function evaluate().""" evaluator.evaluate(predictions=predictions, references=references) assert True @@ -37,35 +36,111 @@ def test_evaluate(predictions = PREDICTIONS, references = REFERENCES): "inputs, expected", [ ( - ( - [1, 2, 3, 4, 5, 6, 7], - [1, 5, 2, 3, 4, 7] - ), - [(1, 1), (None, 5), (2, 2), (3, 3), (4, 4), (5, None), (6, None), (7, 7)] + ([1, 2, 3, 4, 5, 6, 7], [1, 5, 2, 3, 4, 7]), + [(1, 1), (None, 5), (2, 2), (3, 3), (4, 4), (5, None), (6, None), (7, 7)], ), ( - ( - [1, 2, 3], - [7, 8, 9, 1, 2, 4, 5, 3] - ), - [(None, 7), (None, 8), (None, 9), (1, 1), (2, 2), (None, 4), (None, 5), (3, 3)] + ([1, 2, 3], [7, 8, 9, 1, 2, 4, 5, 3]), + [ + (None, 7), + (None, 8), + (None, 9), + (1, 1), + (2, 2), + (None, 4), + (None, 5), + (3, 3), + ], ), ( ( - [2127, 26997, 11, 1556, 276, 56692, 728, 11, 328, 2357, 8977, 29222, 482, 48432, 301, 25, 18002, 2298, 729, 2357, 554], - [2127, 26997, 11, 1556, 276, 56692, 728, 11, 328, 2357, 8977, 2234, 3647, 96383, 25, 18002, 2298, 729, 2357, 554] + [ + 2127, + 26997, + 11, + 1556, + 276, + 56692, + 728, + 11, + 328, + 2357, + 8977, + 29222, + 482, + 48432, + 301, + 25, + 18002, + 2298, + 729, + 2357, + 554, + ], + [ + 2127, + 26997, + 11, + 1556, + 276, + 56692, + 728, + 11, + 328, + 2357, + 8977, + 2234, + 3647, + 96383, + 25, + 18002, + 2298, + 729, + 2357, + 554, + ], ), - [(2127, 2127), (26997, 26997), (11, 11), (1556, 1556), (276, 276), (56692, 56692), (728, 728), (11, 11), (328, 328), (2357, 2357), (8977, 8977), (29222, None), (482, 2234), (48432, 3647), (301, 96383), (25, 25), (18002, 18002), (2298, 2298), (729, 729), (2357, 2357), (554, 554)] + [ + (2127, 2127), + (26997, 26997), + (11, 11), + (1556, 1556), + (276, 276), + (56692, 56692), + (728, 728), + (11, 11), + (328, 328), + (2357, 2357), + (8977, 8977), + (29222, None), + (482, 2234), + (48432, 3647), + (301, 96383), + (25, 25), + (18002, 18002), + (2298, 2298), + (729, 729), + (2357, 2357), + (554, 554), + ], ), ( ( [791, 4865, 374, 389, 279, 282, 2668, 713], - [791, 8415, 374, 304, 279, 38681] + [791, 8415, 374, 304, 279, 38681], ), - [(791, 791), (4865, 8415), (374, 374), (389, 304), (279, 279), (282, None), (2668, None), (713, 38681)] - - ) - ] + [ + (791, 791), + (4865, 8415), + (374, 374), + (389, 304), + (279, 279), + (282, None), + (2668, None), + (713, 38681), + ], + ), + ], ) def test_sequence_alignment(inputs, expected): """Test sequence alignment method.""" @@ -78,95 +153,180 @@ def test_sequence_alignment(inputs, expected): "inputs, expected", [ ( + ([(1, 1), (2, 2), (None, 3), (4, 4)], [(1, 1), (2, 2), (4, 4)]), ( [(1, 1), (2, 2), (None, 3), (4, 4)], - [(1, 1), (2, 2), (4, 4)] + [(1, 1), (2, 2), (None, None), (4, 4)], ), - ( - [(1, 1), (2, 2), (None, 3), (4, 4)], - [(1, 1), (2, 2), (None, None), (4, 4)] - ) ), ( - ( - [(1, 1), (2, 2), (4, 4)], - [(1, 1), (2, 2), (None, 3), (4, 4)] - ), + ([(1, 1), (2, 2), (4, 4)], [(1, 1), (2, 2), (None, 3), (4, 4)]), ( [(1, 1), (2, 2), (None, None), (4, 4)], - [(1, 1), (2, 2), (None, 3), (4, 4)] - ) + [(1, 1), (2, 2), (None, 3), (4, 4)], + ), ), ( - ( - [(1, 1), (None, 7), (2, 2), (4, 4)], - [(1, 1), (2, 2), (None, 3), (4, 4)] - ), + ([(1, 1), (None, 7), (2, 2), (4, 4)], [(1, 1), (2, 2), (None, 3), (4, 4)]), ( [(1, 1), 
(None, 7), (2, 2), (None, None), (4, 4)], - [(1, 1), (None, None), (2, 2), (None, 3), (4, 4)] - ) + [(1, 1), (None, None), (2, 2), (None, 3), (4, 4)], + ), ), ( ( - [(791, 791), (4865, 8415), (374, 374), (389, 304), (279, 279), (282, 38681), (2668, None), (713, None)], - [(791, 791), (4865, 2466), (None, 8415), (374, 374), (389, 304), (279, 279), (282, 38681), (2668, None), (713, None)] + [ + (791, 791), + (4865, 8415), + (374, 374), + (389, 304), + (279, 279), + (282, 38681), + (2668, None), + (713, None), + ], + [ + (791, 791), + (4865, 2466), + (None, 8415), + (374, 374), + (389, 304), + (279, 279), + (282, 38681), + (2668, None), + (713, None), + ], ), ( - [(791, 791), (4865, 8415), (None, None), (374, 374), (389, 304), (279, 279), (282, 38681), (2668, None), (713, None)], - [(791, 791), (4865, 2466), (None, 8415), (374, 374), (389, 304), (279, 279), (282, 38681), (2668, None), (713, None)] - ) - ) - ] + [ + (791, 791), + (4865, 8415), + (None, None), + (374, 374), + (389, 304), + (279, 279), + (282, 38681), + (2668, None), + (713, None), + ], + [ + (791, 791), + (4865, 2466), + (None, 8415), + (374, 374), + (389, 304), + (279, 279), + (282, 38681), + (2668, None), + (713, None), + ], + ), + ), + ], ) def test_align_pairs(inputs, expected): """Test pairs alignment.""" pairs1, pairs2 = inputs - aligned_pairs1, aligned_pairs2 = evaluator.align_pairs(pairs1, pairs2, neutral_pair=(None, None)) + aligned_pairs1, aligned_pairs2 = evaluator.align_pairs( + pairs1, pairs2, neutral_pair=(None, None) + ) assert aligned_pairs1, aligned_pairs2 == expected @pytest.mark.parametrize( "inputs, expected", [ + (([(1, 1), (2, 3), (4, 5), (6, 6)], [(1, 1), (2, 2), (4, 7), (6, 6)]), [0]), ( - ( - [(1, 1), (2, 3), (4, 5), (6, 6)], - [(1, 1), (2, 2), (4, 7), (6, 6)] - ), - [0] - ), - ( - ( [(1, 1), (2, 3), (4, 5), (None, None), (6, 6)], - [(1, 1), (2, 3), (4, 5), (None, 6), (6, 7)] + [(1, 1), (2, 3), (4, 5), (None, 6), (6, 7)], ), - [1, 1] + [1, 1], ), ( - ( - [(1, 1), (2, 3), (5, 6), (7, 8)], - [(1, 1), (2, 4), (5, 6), (7, 9)] - ), - [0, 1, 0] + ([(1, 1), (2, 3), (5, 6), (7, 8)], [(1, 1), (2, 4), (5, 6), (7, 9)]), + [0, 1, 0], ), - ( ( - [(791, 791), (None, None), (4865, 8415), (374, 374), (389, 304), (279, 279), (282, None), (2668, None), (713, 38681)], - [(791, 791), (None, 2466), (4865, 8415), (374, 374), (389, 304), (279, 279), (282, None), (2668, None), (713, 38681)] + [ + (791, 791), + (None, None), + (4865, 8415), + (374, 374), + (389, 304), + (279, 279), + (282, None), + (2668, None), + (713, 38681), + ], + [ + (791, 791), + (None, 2466), + (4865, 8415), + (374, 374), + (389, 304), + (279, 279), + (282, None), + (2668, None), + (713, 38681), + ], ), - [1, 1, 1, 1, 1] + [1, 1, 1, 1, 1], ), ( ( - [(2127, 2127), (26997, 26997), (11, 11), (1556, 1556), (276, 276), (56692, 56692), (728, 728), (11, 11), (328, 328), (2357, 2357), (8977, 8977), (29222, None), (482, 2234), (48432, 3647), (301, 96383), (25, 25), (18002, 18002), (2298, 2298), (729, 729), (2357, 2357), (554, 554)], - [(2127, 2127), (26997, 26997), (11, None), (1556, None), (276, None), (56692, None), (728, None), (11, 11), (328, 328), (2357, 2357), (8977, 8977), (29222, None), (482, 2234), (48432, 3647), (301, 96383), (25, 25), (18002, 18002), (2298, 2298), (729, 729), (2357, 2357), (554, 554)] + [ + (2127, 2127), + (26997, 26997), + (11, 11), + (1556, 1556), + (276, 276), + (56692, 56692), + (728, 728), + (11, 11), + (328, 328), + (2357, 2357), + (8977, 8977), + (29222, None), + (482, 2234), + (48432, 3647), + (301, 96383), + (25, 25), + 
(18002, 18002), + (2298, 2298), + (729, 729), + (2357, 2357), + (554, 554), + ], + [ + (2127, 2127), + (26997, 26997), + (11, None), + (1556, None), + (276, None), + (56692, None), + (728, None), + (11, 11), + (328, 328), + (2357, 2357), + (8977, 8977), + (29222, None), + (482, 2234), + (48432, 3647), + (301, 96383), + (25, 25), + (18002, 18002), + (2298, 2298), + (729, 729), + (2357, 2357), + (554, 554), + ], ), - [1, 1, 1, 1] - ) - ] + [1, 1, 1, 1], + ), + ], ) def test_get_correction_true_positives(inputs, expected): """Test correction precison calculation. @@ -179,32 +339,43 @@ def test_get_correction_true_positives(inputs, expected): ref_pairs, pred_pairs ) assert correction_precision == expected - + + evaluator.get_correction_true_positives( - ref_pairs=[(791, 791), (None, None), (4865, 8415), (374, 374), (389, 304), (279, 279), (282, None), (2668, None), (713, 38681)], - pred_pairs=[(791, 791), (None, 2466), (4865, 8415), (374, 374), (389, 304), (279, 279), (282, None), (2668, None), (713, 38681)] + ref_pairs=[ + (791, 791), + (None, None), + (4865, 8415), + (374, 374), + (389, 304), + (279, 279), + (282, None), + (2668, None), + (713, 38681), + ], + pred_pairs=[ + (791, 791), + (None, 2466), + (4865, 8415), + (374, 374), + (389, 304), + (279, 279), + (282, None), + (2668, None), + (713, 38681), + ], ) @pytest.mark.parametrize( "inputs, expected", [ - ( - ["bœuf", "œuf"], - ["boeuf", "oeuf"] - ), - ( - ["colour", "flavour"], - ["color", "flavor"] - ), - ( - ["maïs", "â"], - ["mais", "a"] - ) - ] + (["bœuf", "œuf"], ["boeuf", "oeuf"]), + (["colour", "flavour"], ["color", "flavor"]), + (["maïs", "â"], ["mais", "a"]), + ], ) def test_normalize(inputs, expected): - """Text normalization before evaluation. - """ + """Text normalization before evaluation.""" normalized_texts = evaluator.normalize(inputs) - assert normalized_texts == expected \ No newline at end of file + assert normalized_texts == expected diff --git a/spellcheck/tests/test_processing.py b/spellcheck/tests/test_processing.py index 05eace9f..a5b2f038 100644 --- a/spellcheck/tests/test_processing.py +++ b/spellcheck/tests/test_processing.py @@ -9,80 +9,49 @@ ( ( "patata 15%, piment 15,2%, pamplemousse 47 %, aubergine 18.1 %", - "patata 15 %, piment 15,2 %, pamplemousse 47 %, aubergine 18.1%" + "patata 15 %, piment 15,2 %, pamplemousse 47 %, aubergine 18.1%", ), - "patata 15%, piment 15,2%, pamplemousse 47 %, aubergine 18.1 %" + "patata 15%, piment 15,2%, pamplemousse 47 %, aubergine 18.1 %", ), ( - ( - "escargot 14'%, olives 28 %", - "escargot 14%, olives 28%" - ), + ("escargot 14'%, olives 28 %", "escargot 14%, olives 28%"), "escargot 14%, olives 28 %", ), ( - ( - "farine 14&'!%, ravioli 47%", - "farine 14 %, ravioli 47 %" - ), + ("farine 14&'!%, ravioli 47%", "farine 14 %, ravioli 47 %"), "farine 14 %, ravioli 47%", ), ( - ( - "farine 14 %, ravioli 47%", - "farine 14'%, raviolo 47 %" - ), - "farine 14'%, raviolo 47%" + ("farine 14 %, ravioli 47%", "farine 14'%, raviolo 47 %"), + "farine 14'%, raviolo 47%", ), - ] + ], ) def test_align_whitespace_percentage(inputs, expected): reference, text = inputs aligned_text = DataProcessor._align_whitespace_percentage( - reference=reference, - text=text + reference=reference, text=text ) assert aligned_text == expected @pytest.mark.parametrize( - "inputs, expected", + "inputs, expected", [ - ( - ( - "oeuf, bœuf", - "œuf, boeuf" - ), - "oeuf, bœuf" - ), - ( - ( - "œuf, boeuf", - "oeuf, bœuf" - ), - "œuf, boeuf" - ), - ( - ( - "bœuf, œuf, ...", - "boeuf, œuf, ..." 
- ), - "bœuf, œuf, ..." - ), + (("oeuf, bœuf", "œuf, boeuf"), "oeuf, bœuf"), + (("œuf, boeuf", "oeuf, bœuf"), "œuf, boeuf"), + (("bœuf, œuf, ...", "boeuf, œuf, ..."), "bœuf, œuf, ..."), # NOTE: Doesn't work since I cannot detect a non-match. Is not a priority for now # ( # ( # "buf, œuf", # "boeuf, œuf" # ), - # "boeuf, œuf" + # "boeuf, œuf" # ), - ] + ], ) def test_align_oe(inputs, expected): reference, text = inputs - aligned_text = DataProcessor._align_oe( - reference=reference, - text=text - ) + aligned_text = DataProcessor._align_oe(reference=reference, text=text) assert aligned_text == expected diff --git a/spellcheck/tests/test_utils.py b/spellcheck/tests/test_utils.py index 67accb81..7265167b 100644 --- a/spellcheck/tests/test_utils.py +++ b/spellcheck/tests/test_utils.py @@ -1,10 +1,6 @@ import pytest -from spellcheck.utils import ( - get_repo_dir, - get_logger, - show_diff -) +from spellcheck.utils import get_repo_dir, get_logger, show_diff DELETED_ELEMENT = "#" @@ -23,19 +19,16 @@ def test_get_logger(): "test_input,expected", [ ( - ( - "hello world", - "hllo borld" - ), - f"h{DELETED_ELEMENT}llo borld" + ("hello world", "hllo borld"), + f"h{DELETED_ELEMENT}llo borld", ), - ] + ], ) def test_show_diff(test_input, expected): original, corrected = test_input highlighted_correction = show_diff( original_text=original, corrected_text=corrected, - deleted_element=DELETED_ELEMENT + deleted_element=DELETED_ELEMENT, ) assert highlighted_correction == expected
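The test files above exercise the SpellcheckEvaluator stage by stage (sequence alignment, align_pairs, get_correction_true_positives, normalize) alongside the DataProcessor alignment helpers and show_diff. For orientation, a short end-to-end sketch of the evaluator as test_evaluate.py and the debug block in evaluator.py drive it, assuming the import path spellcheck.evaluation.evaluator and that evaluate() returns a mapping of metric names to floats (the metric names themselves are not shown in this patch):

from spellcheck.evaluation.evaluator import SpellcheckEvaluator

# Uncorrected OCR text, the human reference, and a model prediction.
originals = ["The cas is on the firdge"]
references = ["The cat is in the fridge"]
predictions = ["The big cat is in the fridge"]

# The evaluator is built on the uncorrected texts, then aligns each
# prediction and reference against them at the token-pair level before
# computing its precision/recall-style metrics.
evaluator = SpellcheckEvaluator(originals=originals)
metrics = evaluator.evaluate(predictions=predictions, references=references)
print(metrics)

The suite itself runs with pytest from the spellcheck/ directory (for instance: pytest tests/), assuming the src/ package is importable, e.g. installed in editable mode.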