From 723786a27c84270c98b813bbd90162fdb10c1977 Mon Sep 17 00:00:00 2001 From: jeswan <57466294+jeswan@users.noreply.github.com> Date: Mon, 18 Jan 2021 21:35:41 -0800 Subject: [PATCH] Switch export_model to use AutoModel and AutoTokenizer (#1260) * refactor export_model to use AutoModel and AutoTokenizer * use AutoConfig, AutoTokenizer, and AutoModel instead of jiant model_type * Switch to hf_pretrained_model_name_or_path. Remove unused tokenizer_path. Update notebooks with AutoClass changes. --- examples/notebooks/jiant_Basic_Example.ipynb | 10 +-- .../notebooks/jiant_EdgeProbing_Example.ipynb | 12 +-- .../jiant_MNLI_Diagnostic_Example.ipynb | 16 ++-- .../notebooks/jiant_Multi_Task_Example.ipynb | 10 +-- examples/notebooks/jiant_STILTs_Example.ipynb | 13 ++- examples/notebooks/jiant_XNLI_Example.ipynb | 13 ++- guides/tutorials/quick_start_main.md | 6 +- jiant/proj/main/export_model.py | 88 ++++++------------- jiant/proj/main/modeling/model_setup.py | 39 +++++--- jiant/proj/main/modeling/taskmodels.py | 2 +- jiant/proj/main/runscript.py | 5 +- jiant/proj/main/tokenize_and_cache.py | 15 ++-- jiant/proj/simple/runscript.py | 32 +++---- .../xtreme/subscripts/a_download_model.sh | 2 +- .../xtreme/subscripts/c_tokenize_and_cache.sh | 48 ++++------ .../benchmarks/xtreme/xtreme_submission.py | 1 - jiant/scripts/download_data/constants.py | 6 +- jiant/shared/model_resolution.py | 16 ++-- jiant/shared/model_setup.py | 38 -------- tests/proj/main/test_export_model.py | 19 ++-- tests/proj/simple/test_runscript.py | 8 +- 21 files changed, 150 insertions(+), 249 deletions(-) diff --git a/examples/notebooks/jiant_Basic_Example.ipynb b/examples/notebooks/jiant_Basic_Example.ipynb index 3f7b12738..5bcd74584 100644 --- a/examples/notebooks/jiant_Basic_Example.ipynb +++ b/examples/notebooks/jiant_Basic_Example.ipynb @@ -158,8 +158,8 @@ }, "outputs": [], "source": [ - "export_model.lookup_and_export_model(\n", - " model_type=\"roberta-base\",\n", + "export_model.export_model(\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " output_base_path=\"./models/roberta-base\",\n", ")" ] @@ -191,8 +191,7 @@ "\n", "tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n", " task_config_path=f\"./tasks/configs/{task_name}_config.json\",\n", - " model_type=\"roberta-base\",\n", - " model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " output_dir=f\"./cache/{task_name}\",\n", " phases=[\"train\", \"val\"],\n", "))" @@ -309,10 +308,9 @@ "run_args = main_runscript.RunConfiguration(\n", " jiant_task_container_config_path=\"./run_configs/mrpc_run_config.json\",\n", " output_dir=\"./runs/mrpc\",\n", - " model_type=\"roberta-base\",\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " model_path=\"./models/roberta-base/model/roberta-base.p\",\n", " model_config_path=\"./models/roberta-base/model/roberta-base.json\",\n", - " model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n", " learning_rate=1e-5,\n", " eval_every_steps=500,\n", " do_train=True,\n", diff --git a/examples/notebooks/jiant_EdgeProbing_Example.ipynb b/examples/notebooks/jiant_EdgeProbing_Example.ipynb index 0defefa03..e3ffd5986 100644 --- a/examples/notebooks/jiant_EdgeProbing_Example.ipynb +++ b/examples/notebooks/jiant_EdgeProbing_Example.ipynb @@ -2702,8 +2702,8 @@ "outputId": "c21bdffa-0ff3-49f3-e734-af5530ab4711" }, "source": [ - "export_model.lookup_and_export_model(\n", - " model_type=\"roberta-base\",\n", + "export_model.export_model(\n", + " 
hf_pretrained_model_name_or_path=\"roberta-base\",\n", " output_base_path=\"./models/roberta-base\",\n", ")" ], @@ -2856,8 +2856,7 @@ "\n", "tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n", " task_config_path=f\"./tasks/configs/{task_name}_config.json\",\n", - " model_type=\"roberta-base\",\n", - " model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " output_dir=f\"./cache/{task_name}\",\n", " phases=[\"train\", \"val\"],\n", "))" @@ -3147,10 +3146,9 @@ "run_args = main_runscript.RunConfiguration(\n", " jiant_task_container_config_path=\"./run_configs/semeval_run_config.json\",\n", " output_dir=\"./runs/semeval\",\n", - " model_type=\"roberta-base\",\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " model_path=\"./models/roberta-base/model/roberta-base.p\",\n", " model_config_path=\"./models/roberta-base/model/roberta-base.json\",\n", - " model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n", " learning_rate=1e-5,\n", " eval_every_steps=500,\n", " do_train=True,\n", @@ -3170,7 +3168,6 @@ " model_type: roberta-base\n", " model_path: ./models/roberta-base/model/roberta-base.p\n", " model_config_path: ./models/roberta-base/model/roberta-base.json\n", - " model_tokenizer_path: ./models/roberta-base/tokenizer\n", " model_load_mode: from_transformers\n", " do_train: True\n", " do_val: True\n", @@ -3204,7 +3201,6 @@ " \"model_type\": \"roberta-base\",\n", " \"model_path\": \"./models/roberta-base/model/roberta-base.p\",\n", " \"model_config_path\": \"./models/roberta-base/model/roberta-base.json\",\n", - " \"model_tokenizer_path\": \"./models/roberta-base/tokenizer\",\n", " \"model_load_mode\": \"from_transformers\",\n", " \"do_train\": true,\n", " \"do_val\": true,\n", diff --git a/examples/notebooks/jiant_MNLI_Diagnostic_Example.ipynb b/examples/notebooks/jiant_MNLI_Diagnostic_Example.ipynb index 5caec04ee..a3a1eaff2 100644 --- a/examples/notebooks/jiant_MNLI_Diagnostic_Example.ipynb +++ b/examples/notebooks/jiant_MNLI_Diagnostic_Example.ipynb @@ -140,8 +140,8 @@ }, "outputs": [], "source": [ - "export_model.lookup_and_export_model(\n", - " model_type=\"roberta-base\",\n", + "export_model.export_model(\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " output_base_path=\"./models/roberta-base\",\n", ")" ] @@ -169,24 +169,21 @@ "# Tokenize and cache each task\n", "tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n", " task_config_path=f\"./tasks/configs/mnli_config.json\",\n", - " model_type=\"roberta-base\",\n", - " model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " output_dir=f\"./cache/mnli\",\n", " phases=[\"train\", \"val\"],\n", "))\n", "\n", "tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n", " task_config_path=f\"./tasks/configs/mnli_mismatched_config.json\",\n", - " model_type=\"roberta-base\",\n", - " model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " output_dir=f\"./cache/mnli_mismatched\",\n", " phases=[\"val\"],\n", "))\n", "\n", "tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n", " task_config_path=f\"./tasks/configs/glue_diagnostics_config.json\",\n", - " model_type=\"roberta-base\",\n", - " model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " output_dir=f\"./cache/glue_diagnostics\",\n", " phases=[\"test\"],\n", 
"))" @@ -323,10 +320,9 @@ "run_args = main_runscript.RunConfiguration(\n", " jiant_task_container_config_path=\"./run_configs/jiant_run_config.json\",\n", " output_dir=\"./runs/run1\",\n", - " model_type=\"roberta-base\",\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " model_path=\"./models/roberta-base/model/roberta-base.p\",\n", " model_config_path=\"./models/roberta-base/model/roberta-base.json\",\n", - " model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n", " learning_rate=1e-5,\n", " eval_every_steps=500,\n", " do_train=True,\n", diff --git a/examples/notebooks/jiant_Multi_Task_Example.ipynb b/examples/notebooks/jiant_Multi_Task_Example.ipynb index d4b41e674..1ebbf465c 100644 --- a/examples/notebooks/jiant_Multi_Task_Example.ipynb +++ b/examples/notebooks/jiant_Multi_Task_Example.ipynb @@ -161,8 +161,8 @@ }, "outputs": [], "source": [ - "export_model.lookup_and_export_model(\n", - " model_type=\"roberta-base\",\n", + "export_model.export_model(\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " output_base_path=\"./models/roberta-base\",\n", ")" ] @@ -193,8 +193,7 @@ "for task_name in [\"rte\", \"stsb\", \"commonsenseqa\"]:\n", " tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n", " task_config_path=f\"./tasks/configs/{task_name}_config.json\",\n", - " model_type=\"roberta-base\",\n", - " model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " output_dir=f\"./cache/{task_name}\",\n", " phases=[\"train\", \"val\"],\n", " ))" @@ -342,10 +341,9 @@ "run_args = main_runscript.RunConfiguration(\n", " jiant_task_container_config_path=\"./run_configs/jiant_run_config.json\",\n", " output_dir=\"./runs/run1\",\n", - " model_type=\"roberta-base\",\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " model_path=\"./models/roberta-base/model/roberta-base.p\",\n", " model_config_path=\"./models/roberta-base/model/roberta-base.json\",\n", - " model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n", " learning_rate=1e-5,\n", " eval_every_steps=500,\n", " do_train=True,\n", diff --git a/examples/notebooks/jiant_STILTs_Example.ipynb b/examples/notebooks/jiant_STILTs_Example.ipynb index 03c32ca18..2292f24f4 100644 --- a/examples/notebooks/jiant_STILTs_Example.ipynb +++ b/examples/notebooks/jiant_STILTs_Example.ipynb @@ -163,8 +163,8 @@ }, "outputs": [], "source": [ - "export_model.lookup_and_export_model(\n", - " model_type=\"roberta-base\",\n", + "export_model.export_model(\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " output_base_path=\"./models/roberta-base\",\n", ")" ] @@ -195,8 +195,7 @@ "for task_name in [\"mnli\", \"rte\"]:\n", " tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n", " task_config_path=f\"./tasks/configs/{task_name}_config.json\",\n", - " model_type=\"roberta-base\",\n", - " model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " output_dir=f\"./cache/{task_name}\",\n", " phases=[\"train\", \"val\"],\n", " ))" @@ -367,10 +366,9 @@ "run_args = main_runscript.RunConfiguration(\n", " jiant_task_container_config_path=\"./run_configs/mnli_run_config.json\",\n", " output_dir=\"./runs/mnli\",\n", - " model_type=\"roberta-base\",\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " model_path=\"./models/roberta-base/model/roberta-base.p\",\n", " model_config_path=\"./models/roberta-base/model/roberta-base.json\",\n", - " 
model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n", " learning_rate=1e-5,\n", " eval_every_steps=500,\n", " do_train=True,\n", @@ -404,11 +402,10 @@ "run_args = main_runscript.RunConfiguration(\n", " jiant_task_container_config_path=\"./run_configs/rte_run_config.json\",\n", " output_dir=\"./runs/mnli___rte\",\n", - " model_type=\"roberta-base\",\n", + " hf_pretrained_model_name_or_path=\"roberta-base\",\n", " model_path=\"./runs/mnli/best_model.p\", # Loading the best model\n", " model_load_mode=\"partial\",\n", " model_config_path=\"./models/roberta-base/model/roberta-base.json\",\n", - " model_tokenizer_path=\"./models/roberta-base/tokenizer\",\n", " learning_rate=1e-5,\n", " eval_every_steps=500,\n", " do_train=True,\n", diff --git a/examples/notebooks/jiant_XNLI_Example.ipynb b/examples/notebooks/jiant_XNLI_Example.ipynb index d7e7e84da..975f0f5a0 100644 --- a/examples/notebooks/jiant_XNLI_Example.ipynb +++ b/examples/notebooks/jiant_XNLI_Example.ipynb @@ -164,8 +164,8 @@ }, "outputs": [], "source": [ - "export_model.lookup_and_export_model(\n", - " model_type=\"xlm-roberta-base\",\n", + "export_model.export_model(\n", + " hf_pretrained_model_name_or_path=\"xlm-roberta-base\",\n", " output_base_path=\"./models/xlm-roberta-base\",\n", ")" ] @@ -197,8 +197,7 @@ "# Tokenize and cache MNLI\n", "tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n", " task_config_path=f\"./tasks/configs/mnli_config.json\",\n", - " model_type=\"xlm-roberta-base\",\n", - " model_tokenizer_path=\"./models/xlm-roberta-base/tokenizer\",\n", + " hf_pretrained_model_name_or_path=\"xlm-roberta-base\",\n", " output_dir=f\"./cache/mnli\",\n", " phases=[\"train\", \"val\"],\n", "))\n", @@ -207,8 +206,7 @@ "for lang in [\"de\", \"zh\"]:\n", " tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n", " task_config_path=f\"./tasks/configs/xnli_{lang}_config.json\",\n", - " model_type=\"xlm-roberta-base\",\n", - " model_tokenizer_path=\"./models/xlm-roberta-base/tokenizer\",\n", + " hf_pretrained_model_name_or_path=\"xlm-roberta-base\",\n", " output_dir=f\"./cache/xnli_{lang}\",\n", " phases=[\"val\"],\n", " ))" @@ -384,10 +382,9 @@ "run_args = main_runscript.RunConfiguration(\n", " jiant_task_container_config_path=\"./run_configs/jiant_run_config.json\",\n", " output_dir=\"./runs/run1\",\n", - " model_type=\"xlm-roberta-base\",\n", + " hf_pretrained_model_name_or_path=\"xlm-roberta-base\",\n", " model_path=\"./models/xlm-roberta-base/model/xlm-roberta-base.p\",\n", " model_config_path=\"./models/xlm-roberta-base/model/xlm-roberta-base.json\",\n", - " model_tokenizer_path=\"./models/xlm-roberta-base/tokenizer\",\n", " learning_rate=1e-5,\n", " eval_every_steps=500,\n", " do_train=True,\n", diff --git a/guides/tutorials/quick_start_main.md b/guides/tutorials/quick_start_main.md index 100463719..e88ba33f3 100644 --- a/guides/tutorials/quick_start_main.md +++ b/guides/tutorials/quick_start_main.md @@ -26,7 +26,7 @@ python jiant/scripts/download_data/runscript.py \ 2. 
Next, we download our RoBERTa-base model ```bash python jiant/proj/main/export_model.py \ - --model_type ${MODEL_TYPE} \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_base_path ${EXP_DIR}/models/${MODEL_TYPE} ``` @@ -34,9 +34,7 @@ python jiant/proj/main/export_model.py \ ```bash python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${EXP_DIR}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path \ - ${EXP_DIR}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${EXP_DIR}/cache/${MODEL_TYPE}/${TASK} \ --phases train,val \ --max_seq_length 256 \ diff --git a/jiant/proj/main/export_model.py b/jiant/proj/main/export_model.py index ee9ab8bb9..7ceae03a8 100644 --- a/jiant/proj/main/export_model.py +++ b/jiant/proj/main/export_model.py @@ -1,8 +1,7 @@ import os -from typing import Tuple, Type import torch -import transformers +from transformers import AutoModelForPreTraining, AutoTokenizer import jiant.utils.python.io as py_io import jiant.utils.zconf as zconf @@ -10,28 +9,12 @@ @zconf.run_config class RunConfiguration(zconf.RunConfig): - model_type = zconf.attr(type=str) + hf_pretrained_model_name_or_path = zconf.attr(type=str) output_base_path = zconf.attr(type=str) - hf_model_name = zconf.attr(type=str, default=None) - - -def lookup_and_export_model(model_type: str, output_base_path: str, hf_model_name: str = None): - model_class, tokenizer_class = get_model_and_tokenizer_classes(model_type) - export_model( - model_type=model_type, - output_base_path=output_base_path, - model_class=model_class, - tokenizer_class=tokenizer_class, - hf_model_name=hf_model_name, - ) def export_model( - model_type: str, - output_base_path: str, - model_class: Type[transformers.PreTrainedModel], - tokenizer_class: Type[transformers.PreTrainedTokenizer], - hf_model_name: str = None, + hf_pretrained_model_name_or_path: str, output_base_path: str, ): """Retrieve model and tokenizer from Transformers and save all necessary data Things saved: @@ -40,66 +23,51 @@ def export_model( - Tokenizer data - JSON file pointing to paths for the above Args: - model_type: Model-type string. See: `get_model_and_tokenizer_classes` + pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): + Can be either: + + - A string, the `model id` of a pretrained model configuration + hosted inside a model repo on okhuggingface.co. + Valid model ids can be located at the root-level, like + ``bert-base-uncased``, or namespaced under a user + or organization name, like ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing a configuration file saved using + the :meth:`~transformers.PretrainedConfig.save_pretrained` method, + or the + :meth:`~transformers.PreTrainedModel.save_pretrained` method, + e.g., ``./my_model_directory/``. + - A path or url to a saved configuration JSON `file`, e.g., + ``./my_model_directory/configuration.json``. 
output_base_path: Base path to save output to - model_class: Model class - tokenizer_class: Tokenizer class - hf_model_name: (Optional) hf_model_name from https://huggingface.co/models, - if it differs from model_type """ - if hf_model_name is None: - hf_model_name = model_type + model = AutoModelForPreTraining.from_pretrained(hf_pretrained_model_name_or_path) + model_type = model.config_class.model_type - tokenizer_fol_path = os.path.join(output_base_path, "tokenizer") model_fol_path = os.path.join(output_base_path, "model") + model_path = os.path.join(model_fol_path, f"{model_type}.p") + model_config_path = os.path.join(model_fol_path, f"{model_type}.json") + tokenizer_fol_path = os.path.join(output_base_path, "tokenizer") + os.makedirs(tokenizer_fol_path, exist_ok=True) os.makedirs(model_fol_path, exist_ok=True) - model_path = os.path.join(model_fol_path, f"{model_type}.p") - model_config_path = os.path.join(model_fol_path, f"{model_type}.json") - model = model_class.from_pretrained(hf_model_name) torch.save(model.state_dict(), model_path) py_io.write_json(model.config.to_dict(), model_config_path) - tokenizer = tokenizer_class.from_pretrained(hf_model_name) + tokenizer = AutoTokenizer.from_pretrained(hf_pretrained_model_name_or_path) tokenizer.save_pretrained(tokenizer_fol_path) config = { "model_type": model_type, "model_path": model_path, "model_config_path": model_config_path, - "model_tokenizer_path": tokenizer_fol_path, } - py_io.write_json(config, os.path.join(output_base_path, f"config.json")) - - -def get_model_and_tokenizer_classes( - model_type: str, -) -> Tuple[Type[transformers.PreTrainedModel], Type[transformers.PreTrainedTokenizer]]: - # We want the chosen model to have all the weights from pretraining (if possible) - class_lookup = { - "bert": (transformers.BertForPreTraining, transformers.BertTokenizer), - "xlm-clm-": (transformers.XLMWithLMHeadModel, transformers.XLMTokenizer), - "roberta": (transformers.RobertaForMaskedLM, transformers.RobertaTokenizer), - "albert": (transformers.AlbertForMaskedLM, transformers.AlbertTokenizer), - "bart": (transformers.BartForConditionalGeneration, transformers.BartTokenizer), - "mbart": (transformers.BartForConditionalGeneration, transformers.MBartTokenizer), - "electra": (transformers.ElectraForPreTraining, transformers.ElectraTokenizer), - } - if model_type.split("-")[0] in class_lookup: - return class_lookup[model_type.split("-")[0]] - elif model_type.startswith("xlm-mlm-") or model_type.startswith("xlm-clm-"): - return transformers.XLMWithLMHeadModel, transformers.XLMTokenizer - elif model_type.startswith("xlm-roberta-"): - return transformers.XLMRobertaForMaskedLM, transformers.XLMRobertaTokenizer - else: - raise KeyError() + py_io.write_json(config, os.path.join(output_base_path, "config.json")) def main(): args = RunConfiguration.default_run_cli() - lookup_and_export_model( - model_type=args.model_type, + export_model( + hf_pretrained_model_name_or_path=args.hf_pretrained_model_name_or_path, output_base_path=args.output_base_path, - hf_model_name=args.hf_model_name, ) diff --git a/jiant/proj/main/modeling/model_setup.py b/jiant/proj/main/modeling/model_setup.py index 6b56710ec..6ff546e72 100644 --- a/jiant/proj/main/modeling/model_setup.py +++ b/jiant/proj/main/modeling/model_setup.py @@ -1,5 +1,8 @@ from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any +from typing import Dict +from typing import List +from typing import Optional import torch import torch.nn as nn @@ -7,28 
+10,41 @@ import jiant.proj.main.components.container_setup as container_setup +import jiant.proj.main.modeling.heads as heads import jiant.proj.main.modeling.primary as primary import jiant.proj.main.modeling.taskmodels as taskmodels -import jiant.proj.main.modeling.heads as heads -import jiant.shared.model_setup as model_setup import jiant.utils.python.strings as strings -from jiant.shared.model_setup import ModelArchitectures -from jiant.tasks import Task, TaskTypes + +from jiant.shared.model_resolution import ModelArchitectures +from jiant.tasks import Task +from jiant.tasks import TaskTypes def setup_jiant_model( - model_type: str, + hf_pretrained_model_name_or_path: str, model_config_path: str, - tokenizer_path: str, task_dict: Dict[str, Task], taskmodels_config: container_setup.TaskmodelsConfig, ): """Sets up tokenizer, encoder, and task models, and instantiates and returns a JiantModel. Args: - model_type (str): model shortcut name. + hf_pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): + Can be either: + + - A string, the `model id` of a predefined tokenizer hosted inside a model + repo on huggingface.co. Valid model ids can be located at the root-level, + like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing vocabulary files required by the + tokenizer, for instance saved using the + :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g., + ``./my_model_directory/``. + - A path or url to a single saved vocabulary file if and only if + the tokenizer only requires a single vocabulary file (like Bert or XLNet), + e.g.: ``./my_model_directory/vocab.txt``. (Not + applicable to all derived classes) model_config_path (str): Path to the JSON file containing the configuration parameters. - tokenizer_path (str): path to tokenizer directory. task_dict (Dict[str, tasks.Task]): map from task name to task instance. taskmodels_config: maps mapping from tasks to models, and specifying task-model configs. @@ -36,9 +52,10 @@ def setup_jiant_model( JiantModel nn.Module. 
""" - model_arch = ModelArchitectures.from_model_type(model_type) + model = transformers.AutoModel.from_pretrained(hf_pretrained_model_name_or_path) + model_arch = ModelArchitectures.from_model_type(model.base_model_prefix) transformers_class_spec = TRANSFORMERS_CLASS_SPEC_DICT[model_arch] - tokenizer = model_setup.get_tokenizer(model_type=model_type, tokenizer_path=tokenizer_path) + tokenizer = transformers.AutoTokenizer.from_pretrained(hf_pretrained_model_name_or_path) ancestor_model = get_ancestor_model( transformers_class_spec=transformers_class_spec, model_config_path=model_config_path, ) diff --git a/jiant/proj/main/modeling/taskmodels.py b/jiant/proj/main/modeling/taskmodels.py index f656229d9..53d5e70c7 100644 --- a/jiant/proj/main/modeling/taskmodels.py +++ b/jiant/proj/main/modeling/taskmodels.py @@ -9,7 +9,7 @@ import jiant.utils.transformer_utils as transformer_utils from jiant.proj.main.components.outputs import LogitsOutput, LogitsAndLossOutput from jiant.utils.python.datastructures import take_one -from jiant.shared.model_setup import ModelArchitectures +from jiant.shared.model_resolution import ModelArchitectures class Taskmodel(nn.Module, metaclass=abc.ABCMeta): diff --git a/jiant/proj/main/runscript.py b/jiant/proj/main/runscript.py index 705e39cda..11ee44ad9 100644 --- a/jiant/proj/main/runscript.py +++ b/jiant/proj/main/runscript.py @@ -21,10 +21,10 @@ class RunConfiguration(zconf.RunConfig): output_dir = zconf.attr(type=str, required=True) # === Model parameters === # + hf_pretrained_model_name_or_path = zconf.attr(type=str, required=True) model_type = zconf.attr(type=str, required=True) model_path = zconf.attr(type=str, required=True) model_config_path = zconf.attr(default=None, type=str) - model_tokenizer_path = zconf.attr(default=None, type=str) model_load_mode = zconf.attr(default="from_transformers", type=str) # === Running Setup === # @@ -85,9 +85,8 @@ def setup_runner( with distributed.only_first_process(local_rank=args.local_rank): # load the model jiant_model = jiant_model_setup.setup_jiant_model( - model_type=args.model_type, + hf_pretrained_model_name_or_path=args.hf_pretrained_model_name_or_path, model_config_path=args.model_config_path, - tokenizer_path=args.model_tokenizer_path, task_dict=jiant_task_container.task_dict, taskmodels_config=jiant_task_container.taskmodels_config, ) diff --git a/jiant/proj/main/tokenize_and_cache.py b/jiant/proj/main/tokenize_and_cache.py index 9fe19bdf6..51688b735 100644 --- a/jiant/proj/main/tokenize_and_cache.py +++ b/jiant/proj/main/tokenize_and_cache.py @@ -1,9 +1,10 @@ import os +from transformers import AutoConfig, AutoTokenizer + import jiant.proj.main.preprocessing as preprocessing import jiant.shared.caching as shared_caching import jiant.shared.model_resolution as model_resolution -import jiant.shared.model_setup as model_setup import jiant.tasks as tasks import jiant.tasks.evaluate as evaluate import jiant.utils.zconf as zconf @@ -15,8 +16,7 @@ class RunConfiguration(zconf.RunConfig): # === Required parameters === # task_config_path = zconf.attr(type=str, required=True) - model_type = zconf.attr(type=str, required=True) - model_tokenizer_path = zconf.attr(type=str, required=True) + hf_pretrained_model_name_or_path = zconf.attr(type=str, required=True) output_dir = zconf.attr(type=str, required=True) # === Optional parameters === # @@ -142,13 +142,14 @@ def iter_chunk_and_save(task, phase, examples, feat_spec, tokenizer, args: RunCo def main(args: RunConfiguration): + config = 
AutoConfig.from_pretrained(args.hf_pretrained_model_name_or_path) + model_type = config.model_type + task = tasks.create_task_from_config_path(config_path=args.task_config_path, verbose=True) feat_spec = model_resolution.build_featurization_spec( - model_type=args.model_type, max_seq_length=args.max_seq_length, - ) - tokenizer = model_setup.get_tokenizer( - model_type=args.model_type, tokenizer_path=args.model_tokenizer_path, + model_type=model_type, max_seq_length=args.max_seq_length, ) + tokenizer = AutoTokenizer.from_pretrained(args.hf_pretrained_model_name_or_path) if isinstance(args.phases, str): phases = args.phases.split(",") else: diff --git a/jiant/proj/simple/runscript.py b/jiant/proj/simple/runscript.py index 8d77c6f92..d534ee50a 100644 --- a/jiant/proj/simple/runscript.py +++ b/jiant/proj/simple/runscript.py @@ -2,6 +2,8 @@ import torch +from transformers import AutoConfig + import jiant.proj.main.write_task_configs as write_task_configs import jiant.proj.main.export_model as export_model import jiant.proj.main.tokenize_and_cache as tokenize_and_cache @@ -21,7 +23,7 @@ class RunConfiguration(zconf.RunConfig): data_dir = zconf.attr(type=str, required=True) # === Model parameters === # - model_type = zconf.attr(type=str, required=True) + hf_pretrained_model_name_or_path = zconf.attr(type=str, required=True) model_weights_path = zconf.attr(type=str, default=None) model_cache_path = zconf.attr(type=str, default=None) @@ -100,6 +102,7 @@ def create_and_write_task_configs(task_name_list, data_dir, task_config_base_pat def run_simple(args: RunConfiguration, with_continue: bool = False): + hf_config = AutoConfig.from_pretrained(args.hf_pretrained_model_name_or_path) model_cache_path = replace_none( args.model_cache_path, default=os.path.join(args.exp_dir, "models") @@ -122,11 +125,11 @@ def run_simple(args: RunConfiguration, with_continue: bool = False): ) # === Step 2: Download models === # - if not os.path.exists(os.path.join(model_cache_path, args.model_type)): + if not os.path.exists(os.path.join(model_cache_path, hf_config.model_type)): print("Downloading model") - export_model.lookup_and_export_model( - model_type=args.model_type, - output_base_path=os.path.join(model_cache_path, args.model_type), + export_model.export_model( + hf_pretrained_model_name_or_path=args.hf_pretrained_model_name_or_path, + output_base_path=os.path.join(model_cache_path, hf_config.model_type), ) # === Step 3: Tokenize and cache === # @@ -139,7 +142,7 @@ def run_simple(args: RunConfiguration, with_continue: bool = False): phases_to_do = [] for phase, phase_task_list in phase_task_dict.items(): if task_name in phase_task_list and not os.path.exists( - os.path.join(args.exp_dir, "cache", args.model_type, task_name, phase) + os.path.join(args.exp_dir, "cache", hf_config.model_type, task_name, phase) ): phases_to_do.append(phase) if not phases_to_do: @@ -148,11 +151,8 @@ def run_simple(args: RunConfiguration, with_continue: bool = False): tokenize_and_cache.main( tokenize_and_cache.RunConfiguration( task_config_path=task_config_path_dict[task_name], - model_type=args.model_type, - model_tokenizer_path=os.path.join( - model_cache_path, args.model_type, "tokenizer" - ), - output_dir=os.path.join(args.exp_dir, "cache", args.model_type, task_name), + hf_pretrained_model_name_or_path=args.hf_pretrained_model_name_or_path, + output_dir=os.path.join(args.exp_dir, "cache", hf_config.model_type, task_name), phases=phases_to_do, # TODO: Need a strategy for task-specific max_seq_length issues (issue #1176) 
max_seq_length=args.max_seq_length, @@ -166,7 +166,7 @@ def run_simple(args: RunConfiguration, with_continue: bool = False): # number of moving parts. jiant_task_container_config = configurator.SimpleAPIMultiTaskConfigurator( task_config_base_path=os.path.join(args.data_dir, "configs"), - task_cache_base_path=os.path.join(args.exp_dir, "cache", args.model_type), + task_cache_base_path=os.path.join(args.exp_dir, "cache", hf_config.model_type), train_task_name_list=args.train_tasks, val_task_name_list=args.val_tasks, test_task_name_list=args.test_tasks, @@ -193,7 +193,7 @@ def run_simple(args: RunConfiguration, with_continue: bool = False): else: model_load_mode = "from_transformers" model_weights_path = os.path.join( - model_cache_path, args.model_type, "model", f"{args.model_type}.p" + model_cache_path, hf_config.model_type, "model", f"{hf_config.model_type}.p" ) run_output_dir = os.path.join(args.exp_dir, "runs", args.run_name) @@ -212,12 +212,12 @@ def run_simple(args: RunConfiguration, with_continue: bool = False): jiant_task_container_config_path=jiant_task_container_config_path, output_dir=run_output_dir, # === Model parameters === # - model_type=args.model_type, + hf_pretrained_model_name_or_path=args.hf_pretrained_model_name_or_path, + model_type=hf_config.model_type, model_path=model_weights_path, model_config_path=os.path.join( - model_cache_path, args.model_type, "model", f"{args.model_type}.json" + model_cache_path, hf_config.model_type, "model", f"{hf_config.model_type}.json" ), - model_tokenizer_path=os.path.join(model_cache_path, args.model_type, "tokenizer"), model_load_mode=model_load_mode, # === Running Setup === # do_train=bool(args.train_tasks), diff --git a/jiant/scripts/benchmarks/xtreme/subscripts/a_download_model.sh b/jiant/scripts/benchmarks/xtreme/subscripts/a_download_model.sh index 2d177e6eb..d9605b8de 100644 --- a/jiant/scripts/benchmarks/xtreme/subscripts/a_download_model.sh +++ b/jiant/scripts/benchmarks/xtreme/subscripts/a_download_model.sh @@ -6,5 +6,5 @@ # This downloads a model (e.g. 
xlm-roberta-large) python jiant/proj/main/export_model.py \ - --model_type ${MODEL_TYPE} \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_base_path ${BASE_PATH}/models/${MODEL_TYPE} diff --git a/jiant/scripts/benchmarks/xtreme/subscripts/c_tokenize_and_cache.sh b/jiant/scripts/benchmarks/xtreme/subscripts/c_tokenize_and_cache.sh index 0d51a18c6..48e95cb59 100644 --- a/jiant/scripts/benchmarks/xtreme/subscripts/c_tokenize_and_cache.sh +++ b/jiant/scripts/benchmarks/xtreme/subscripts/c_tokenize_and_cache.sh @@ -10,8 +10,7 @@ for LG in ar bg de el en es fr hi ru sw th tr ur vi zh; do TASK=xnli_${LG} python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases val,test \ --max_seq_length 256 \ @@ -22,8 +21,7 @@ done TASK=pawsx_en python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases train,val,test \ --max_seq_length 256 \ @@ -32,8 +30,7 @@ for LG in ar de es fr ja ko zh; do TASK=pawsx_${LG} python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases val,test \ --max_seq_length 256 \ @@ -44,8 +41,7 @@ done TASK=udpos_en python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases train,val,test \ --max_seq_length 256 \ @@ -54,8 +50,7 @@ for LG in af ar bg de el es et eu fa fi fr he hi hu id it ja ko mr nl pt ru ta t TASK=udpos_${LG} python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases val,test \ --max_seq_length 256 \ @@ -65,8 +60,7 @@ for LG in kk th tl yo; do TASK=udpos_${LG} python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases test \ --max_seq_length 256 \ @@ -77,8 +71,7 @@ done TASK=panx_en python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases train,val,test \ 
--max_seq_length 256 \ @@ -87,8 +80,7 @@ for LG in af ar bg bn de el es et eu fa fi fr he hi hu id it ja jv ka kk ko ml m TASK=panx_${LG} python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases val,test \ --max_seq_length 256 \ @@ -100,8 +92,7 @@ for LG in ar de el en es hi ru th tr vi zh; do TASK=xquad_${LG} python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases val \ --max_seq_length 384 \ @@ -113,8 +104,7 @@ for LG in ar de en es hi vi zh; do TASK=mlqa_${LG}_${LG} python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases val,test \ --max_seq_length 384 \ @@ -125,8 +115,7 @@ done TASK=tydiqa_en python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases train,val \ --max_seq_length 384 \ @@ -135,8 +124,7 @@ for LG in ar bn fi id ko ru sw te; do TASK=tydiqa_${LG} python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases val \ --max_seq_length 384 \ @@ -148,8 +136,7 @@ for LG in de fr ru zh; do TASK=bucc2018_${LG} python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases val,test \ --max_seq_length 512 \ @@ -161,8 +148,7 @@ for LG in af ar bg bn de el es et eu fa fi fr he hi hu id it ja jv ka kk ko ml m TASK=tatoeba_${LG} python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases val \ --max_seq_length 512 \ @@ -173,8 +159,7 @@ done TASK=mnli python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases train,val \ 
--max_seq_length 256 \ @@ -182,8 +167,7 @@ python jiant/proj/main/tokenize_and_cache.py \ TASK=squad_v1 python jiant/proj/main/tokenize_and_cache.py \ --task_config_path ${BASE_PATH}/tasks/configs/${TASK}_config.json \ - --model_type ${MODEL_TYPE} \ - --model_tokenizer_path ${BASE_PATH}/models/${MODEL_TYPE}/tokenizer \ + --hf_pretrained_model_name_or_path ${MODEL_TYPE} \ --output_dir ${BASE_PATH}/cache/${MODEL_TYPE}/${TASK} \ --phases train,val \ --max_seq_length 384 \ diff --git a/jiant/scripts/benchmarks/xtreme/xtreme_submission.py b/jiant/scripts/benchmarks/xtreme/xtreme_submission.py index d07c92194..afdbc97a2 100644 --- a/jiant/scripts/benchmarks/xtreme/xtreme_submission.py +++ b/jiant/scripts/benchmarks/xtreme/xtreme_submission.py @@ -40,7 +40,6 @@ class RunConfiguration(zconf.RunConfig): model_type = zconf.attr(type=str, required=True) model_path = zconf.attr(type=str, required=True) model_config_path = zconf.attr(default=None, type=str) - model_tokenizer_path = zconf.attr(default=None, type=str) model_load_mode = zconf.attr(default="from_ptt", type=str) # === Nuisance Parameters === # diff --git a/jiant/scripts/download_data/constants.py b/jiant/scripts/download_data/constants.py index 8eb648162..badd231c4 100644 --- a/jiant/scripts/download_data/constants.py +++ b/jiant/scripts/download_data/constants.py @@ -2,9 +2,9 @@ # is not suitable SQUAD_TASKS = {"squad_v1", "squad_v2"} DIRECT_SUPERGLUE_TASKS_TO_DATA_URLS = { - "wsc": f"https://dl.fbaipublicfiles.com/glue/superglue/data/v2/WSC.zip", - "multirc": f"https://dl.fbaipublicfiles.com/glue/superglue/data/v2/MultiRC.zip", - "record": f"https://dl.fbaipublicfiles.com/glue/superglue/data/v2/ReCoRD.zip", + "wsc": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/WSC.zip", + "multirc": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/MultiRC.zip", + "record": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/ReCoRD.zip", } OTHER_DOWNLOAD_TASKS = { diff --git a/jiant/shared/model_resolution.py b/jiant/shared/model_resolution.py index e59f86975..7ec446137 100644 --- a/jiant/shared/model_resolution.py +++ b/jiant/shared/model_resolution.py @@ -27,23 +27,23 @@ def from_model_type(cls, model_type: str): Model architecture associated with the provided shortcut name. 
""" - if model_type.startswith("bert-"): + if model_type.startswith("bert"): return cls.BERT - elif model_type.startswith("xlm-") and not model_type.startswith("xlm-roberta"): + elif model_type.startswith("xlm") and not model_type.startswith("xlm-roberta"): return cls.XLM - elif model_type.startswith("roberta-"): + elif model_type.startswith("roberta"): return cls.ROBERTA - elif model_type.startswith("albert-"): + elif model_type.startswith("albert"): return cls.ALBERT elif model_type == "glove_lstm": return cls.GLOVE_LSTM - elif model_type.startswith("xlm-roberta-"): + elif model_type.startswith("xlm-roberta"): return cls.XLM_ROBERTA - elif model_type.startswith("bart-"): + elif model_type.startswith("bart"): return cls.BART - elif model_type.startswith("mbart-"): + elif model_type.startswith("mbart"): return cls.MBART - elif model_type.startswith("electra-"): + elif model_type.startswith("electra"): return cls.ELECTRA else: raise KeyError(model_type) diff --git a/jiant/shared/model_setup.py b/jiant/shared/model_setup.py index 427e399c5..2433c826b 100644 --- a/jiant/shared/model_setup.py +++ b/jiant/shared/model_setup.py @@ -2,44 +2,6 @@ import torch from jiant.ext.radam import RAdam -from jiant.shared.model_resolution import ModelArchitectures, resolve_tokenizer_class - - -def get_tokenizer(model_type, tokenizer_path): - """Instantiate a tokenizer for a given model type. - - Args: - model_type (str): model shortcut name. - tokenizer_path (str): path to tokenizer directory. - - Returns: - Tokenizer for the given model type. - - """ - model_arch = ModelArchitectures.from_model_type(model_type) - tokenizer_class = resolve_tokenizer_class(model_type) - if model_arch in [ModelArchitectures.BERT]: - if "-cased" in model_type: - do_lower_case = False - elif "-uncased" in model_type: - do_lower_case = True - else: - raise RuntimeError(model_type) - elif model_arch in [ - ModelArchitectures.XLM, - ModelArchitectures.ROBERTA, - ModelArchitectures.XLM_ROBERTA, - ModelArchitectures.BART, - ModelArchitectures.MBART, - ModelArchitectures.ELECTRA, - ]: - do_lower_case = False - elif model_arch in [ModelArchitectures.ALBERT]: - do_lower_case = True - else: - raise RuntimeError(str(tokenizer_class)) - tokenizer = tokenizer_class.from_pretrained(tokenizer_path, do_lower_case=do_lower_case) - return tokenizer class OptimizerScheduler: diff --git a/tests/proj/main/test_export_model.py b/tests/proj/main/test_export_model.py index 5e8a87aae..1dd1cdd8a 100644 --- a/tests/proj/main/test_export_model.py +++ b/tests/proj/main/test_export_model.py @@ -7,27 +7,18 @@ @pytest.mark.parametrize( - "model_type, model_class, tokenizer_class, hf_model_name", + "model_type, model_class, hf_pretrained_model_name_or_path", [ - ("bert-base-cased", BertPreTrainedModel, BertTokenizer, "bert-base-cased"), - ( - "roberta-med-small-1M-1", - RobertaForMaskedLM, - RobertaTokenizer, - "nyu-mll/roberta-med-small-1M-1", - ), + ("bert", BertPreTrainedModel, "bert-base-cased"), + ("roberta", RobertaForMaskedLM, "nyu-mll/roberta-med-small-1M-1",), ], ) -def test_export_model(tmp_path, model_type, model_class, tokenizer_class, hf_model_name): +def test_export_model(tmp_path, model_type, model_class, hf_pretrained_model_name_or_path): export_model( - model_type=model_type, + hf_pretrained_model_name_or_path=hf_pretrained_model_name_or_path, output_base_path=tmp_path, - model_class=model_class, - tokenizer_class=tokenizer_class, - hf_model_name=hf_model_name, ) read_config = py_io.read_json(os.path.join(tmp_path, f"config.json")) assert 
read_config["model_type"] == model_type assert read_config["model_path"] == os.path.join(tmp_path, "model", f"{model_type}.p") assert read_config["model_config_path"] == os.path.join(tmp_path, "model", f"{model_type}.json") - assert read_config["model_tokenizer_path"] == os.path.join(tmp_path, "tokenizer") diff --git a/tests/proj/simple/test_runscript.py b/tests/proj/simple/test_runscript.py index 981b78e08..a17da99f2 100644 --- a/tests/proj/simple/test_runscript.py +++ b/tests/proj/simple/test_runscript.py @@ -21,7 +21,7 @@ def test_simple_runscript(tmpdir, task_name, model_type): run_name=RUN_NAME, exp_dir=exp_dir, data_dir=data_dir, - model_type=model_type, + hf_pretrained_model_name_or_path=model_type, tasks=task_name, train_examples_cap=16, train_batch_size=16, @@ -47,7 +47,7 @@ def test_simple_runscript_save(tmpdir, task_name, model_type): run_name=run_name, exp_dir=exp_dir, data_dir=data_dir, - model_type=model_type, + hf_pretrained_model_name_or_path=model_type, tasks=task_name, max_steps=1, train_batch_size=32, @@ -78,7 +78,7 @@ def test_simple_runscript_save(tmpdir, task_name, model_type): run_name=run_name, exp_dir=exp_dir, data_dir=data_dir, - model_type=model_type, + hf_pretrained_model_name_or_path=model_type, tasks=task_name, max_steps=1, train_batch_size=16, @@ -98,7 +98,7 @@ def test_simple_runscript_save(tmpdir, task_name, model_type): run_name=run_name, exp_dir=exp_dir, data_dir=data_dir, - model_type=model_type, + hf_pretrained_model_name_or_path=model_type, tasks=task_name, max_steps=1, train_batch_size=16,
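
Note on the refactored workflow: the sketch below stitches together, in one place, how the three pipeline steps are invoked after this change, mirroring the notebook cells updated above (jiant_Basic_Example.ipynb). The import aliases (export_model, tokenize_and_cache, main_runscript) and the roberta-base/MRPC paths are the ones used in the example notebooks, not requirements; any Hugging Face model id or local checkpoint directory accepted by AutoModel/AutoTokenizer should work for hf_pretrained_model_name_or_path.

# Minimal sketch of the post-patch API, assuming the notebook import aliases and paths.
import jiant.proj.main.export_model as export_model
import jiant.proj.main.tokenize_and_cache as tokenize_and_cache
import jiant.proj.main.runscript as main_runscript

# 1. Export model + tokenizer by HF name; AutoModelForPreTraining/AutoTokenizer
#    resolve the concrete classes, so no jiant model_type lookup table is needed.
export_model.export_model(
    hf_pretrained_model_name_or_path="roberta-base",
    output_base_path="./models/roberta-base",
)

# 2. Tokenize and cache a task; the tokenizer is loaded from the same HF name,
#    replacing the old model_type / model_tokenizer_path pair.
tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(
    task_config_path="./tasks/configs/mrpc_config.json",
    hf_pretrained_model_name_or_path="roberta-base",
    output_dir="./cache/mrpc",
    phases=["train", "val"],
))

# 3. Configure a run; model weights and config still point at the exported files.
run_args = main_runscript.RunConfiguration(
    jiant_task_container_config_path="./run_configs/mrpc_run_config.json",
    output_dir="./runs/mrpc",
    hf_pretrained_model_name_or_path="roberta-base",
    model_path="./models/roberta-base/model/roberta-base.p",
    model_config_path="./models/roberta-base/model/roberta-base.json",
    learning_rate=1e-5,
    eval_every_steps=500,
    do_train=True,
    # remaining flags (do_val, etc.) as in the notebook cells above
)
# The notebooks then launch training with main_runscript.run_loop(run_args).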