From 7066a3a3ba71b7cf1ed8639d59069f942c924086 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Thu, 11 Apr 2024 14:55:17 +0200 Subject: [PATCH] Fix up --- utils/deprecate_models.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/utils/deprecate_models.py b/utils/deprecate_models.py index 166f555071beea..2719318c132253 100644 --- a/utils/deprecate_models.py +++ b/utils/deprecate_models.py @@ -8,14 +8,16 @@ import argparse import os from collections import defaultdict -from typing import Tuple, Optional +from pathlib import Path +from typing import Optional, Tuple import requests from git import Repo from packaging import version +from transformers import CONFIG_MAPPING, logging from transformers import __version__ as current_version -from transformers import logging, CONFIG_MAPPING + REPO_PATH = Path(os.path.abspath(os.path.dirname(os.path.dirname(__file__)))) repo = Repo(REPO_PATH) @@ -84,19 +86,19 @@ def get_model_doc_path(model: str) -> Tuple[Optional[str], Optional[str]]: model_doc_path = REPO_PATH / f"docs/source/en/model_doc/{model.replace('_', '-')}.md" if os.path.exists(model_doc_path): - return model_doc_path, model.replace('_', '-') + return model_doc_path, model.replace("_", "-") # Try replacing _ with "" in the model name model_doc_path = REPO_PATH / f"docs/source/en/model_doc/{model.replace('_', '')}.md" if os.path.exists(model_doc_path): - return model_doc_path, model.replace('_', '') + return model_doc_path, model.replace("_", "") return None, None def extract_model_info(model): - model_info = dict() + model_info = {} model_doc_path, model_doc_name = get_model_doc_path(model) model_path = REPO_PATH / f"src/transformers/models/{model}" @@ -173,7 +175,6 @@ def get_line_indent(s): maybe_else_block = [] in_else_block = False in_base_imports = False - base_import_block = [] open_indent_level = -1 # We iterate over each line in the init file to create a new init file @@ -290,7 +291,7 @@ def update_init_file(filename, models): f.write(init_file) -def remove_model_references_from_file(filename, models, condition=None): +def remove_model_references_from_file(filename, models, condition): """ Remove all references to the given models from the given file @@ -299,9 +300,6 @@ def remove_model_references_from_file(filename, models, condition=None): models (List[str]): The models to remove condition (Callable): A function that takes the line and model and returns True if the line should be removed """ - if condition is None: - condition = lambda line, model: model == line.strip() - with open(filename, "r") as f: init_file = f.read() @@ -372,8 +370,8 @@ def deprecate_models(models): for model, model_info in models_info.items(): if model in CONFIG_MAPPING: model_config_classes.append(CONFIG_MAPPING[model].__name__) - elif model_info['model_doc_name'] in CONFIG_MAPPING: - model_config_classes.append(CONFIG_MAPPING[model_info['model_doc_name']].__name__) + elif model_info["model_doc_name"] in CONFIG_MAPPING: + model_config_classes.append(CONFIG_MAPPING[model_info["model_doc_name"]].__name__) else: skipped_models.append(model) print(f"Model config class not found for model: {model}") @@ -385,7 +383,7 @@ def deprecate_models(models): print(f"Models to deprecate: {models}") # Remove model config classes from config check - print(f"Removing model config classes from config checks") + print("Removing model config classes from config checks") remove_model_config_classes_from_config_check("src/transformers/configuration_utils.py", model_config_classes) tip_message = build_tip_message(get_last_stable_minor_release()) @@ -407,13 +405,17 @@ def deprecate_models(models): # We do the following with all models passed at once to avoid having to re-write the file multiple times # Update the __init__.py file to point to the deprecated model. - print(f"Updating __init__.py file to point to the deprecated models") + print("Updating __init__.py file to point to the deprecated models") update_init_file("src/transformers/__init__.py", models) # Remove model references from other files - print(f"Removing model references from other files") - remove_model_references_from_file("src/transformers/models/__init__.py", models, lambda line, model: model == line.strip().strip(",")) - remove_model_references_from_file("utils/slow_documentation_tests.txt", models, lambda line, model: "/" + model + "/" in line) + print("Removing model references from other files") + remove_model_references_from_file( + "src/transformers/models/__init__.py", models, lambda line, model: model == line.strip().strip(",") + ) + remove_model_references_from_file( + "utils/slow_documentation_tests.txt", models, lambda line, model: "/" + model + "/" in line + ) if __name__ == "__main__":