From f8899a2c3beb1bec0384f4cedf0437fd043e1435 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Sat, 29 Apr 2023 20:50:23 +0000 Subject: [PATCH 01/11] initial update --- examples/opt_finetune/run_easylm_flax.py | 1006 ++++++++++++++++++++++ examples/opt_finetune/run_llama.sh | 21 + 2 files changed, 1027 insertions(+) create mode 100644 examples/opt_finetune/run_easylm_flax.py create mode 100644 examples/opt_finetune/run_llama.sh diff --git a/examples/opt_finetune/run_easylm_flax.py b/examples/opt_finetune/run_easylm_flax.py new file mode 100644 index 000000000..43e797ad3 --- /dev/null +++ b/examples/opt_finetune/run_easylm_flax.py @@ -0,0 +1,1006 @@ +# TODO: +# 1. Import Llama Model Definition(done); +# 2. Import Manual partition spec; +# 3. Import Fastchat dataset; +# 4. Weight Conversion(done); +# 5. Distributed load/store. + +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. 
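For orientation, a minimal launch sketch that mirrors the companion examples/opt_finetune/run_llama.sh added by this patch; the LLaMA checkpoint path and the EasyLM checkout location are placeholders for local paths:

    export PYTHONPATH=$HOME/alpa-proj/EasyLM:$PYTHONPATH
    python3 run_easylm_flax.py \
        --output_dir ./output \
        --model_name_or_path $HOME/alpa-proj/llama-7b \
        --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
        --do_train --do_eval --block_size 1024 \
        --num_micro_batches 64 --operator_parallel 1 --pipeline_parallel 4 \
        --dtype float16 --learning_rate 5e-4 --warmup_steps 2000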
+ +import json +import logging +import math +import os +import sys +import time +from dataclasses import asdict, dataclass, field +from enum import Enum +import functools +from itertools import chain +from pathlib import Path +from typing import Callable, Optional + +import datasets +import numpy as np +from datasets import Dataset, load_dataset +from tqdm import tqdm + +import alpa +from alpa.model.model_util import DynamicScale, TrainState +from alpa import AutoShardingOption, AutoLayerOption, ManualStageOption +import jax +import jax.numpy as jnp +import optax +import transformers +import tensorflow as tf +from flax import traverse_util +from huggingface_hub import Repository +from transformers import ( + CONFIG_MAPPING, + FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, + AutoConfig, + AutoTokenizer, + FlaxAutoModelForCausalLM, + HfArgumentParser, + is_tensorboard_available, + set_seed, +) + +alpa.init(cluster="ray") + +from transformers.testing_utils import CaptureLogger +from transformers.utils import get_full_repo_name, send_example_telemetry + +tf.config.experimental.set_visible_devices([], 'GPU') + +from EasyLM.EasyLM.models.llama.llama_model import ( + LLaMAConfig, FlaxLLaMAForCausalLMModule, FlaxLLaMAForCausalLM +) + +logger = logging.getLogger(__name__) + +MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +@dataclass +class TrainingArguments: + output_dir: str = field( + metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, + ) + overwrite_output_dir: bool = field( + default=False, + metadata={ + "help": ( + "Overwrite the content of the output directory. " + "Use this to continue training if output_dir points to a checkpoint directory." 
+ ) + }, + ) + do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) + do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) + per_device_train_batch_size: int = field( + default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."} + ) + per_device_eval_batch_size: int = field( + default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."} + ) + num_micro_batches: int = field(default=1, metadata={"help": "The number of micro batches for gradient accumulation."}) + operator_parallel: int = field(default=1, metadata={"help": "The degree of operator model parallelism."}) + pipeline_parallel: int = field(default=1, metadata={"help": "The degree of pipeline model parallelism."}) + use_remat: bool = field(default=True, metadata={"help": "Whether or not to use gradient rematerilization/gradient checkpointing."}) + learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."}) + weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."}) + adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}) + adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}) + adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}) + adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."}) + num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."}) + warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) + logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."}) + save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."}) + eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."}) + seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."}) + push_to_hub: bool = field( + default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."} + ) + hub_model_id: str = field( + default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."} + ) + hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."}) + + def __post_init__(self): + if self.output_dir is not None: + self.output_dir = os.path.expanduser(self.output_dir) + + def to_dict(self): + """ + Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates + the token values by removing their value. + """ + d = asdict(self) + for k, v in d.items(): + if isinstance(v, Enum): + d[k] = v.value + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum): + d[k] = [x.value for x in v] + if k.endswith("_token"): + d[k] = f"<{k.upper()}>" + return d + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." 
+ ) + }, + ) + model_type: Optional[str] = field( + default=None, + metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + dtype: Optional[str] = field( + default="float32", + metadata={ + "help": ( + "Floating-point format in which the model weights should be initialized and trained. Choose one of" + " `[float32, float16, bfloat16]`." + ) + }, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": ( + "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + ) + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ) + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + validation_split_percentage: Optional[int] = field( + default=5, + metadata={ + "help": "The percentage of the train set used as validation set in case there's no validation split" + }, + ) + block_size: Optional[int] = field( + default=None, + metadata={ + "help": ( + "Optional input sequence length after tokenization. " + "The training dataset will be truncated in block of this size for training. " + "Default to the model max input length for single sentence inputs (take into account special tokens)." 
+ ) + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + keep_linebreaks: bool = field( + default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} + ) + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." + + +def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, + min_batch_size: int, shuffle: bool = False): + """ + Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices. + Shuffle batches if `shuffle` is `True`. + """ + if len(dataset) < batch_size: + assert len(dataset) >= min_batch_size + batch_size = len(dataset) // min_batch_size * min_batch_size + + data_collator = transformers.DefaultDataCollator("np") + tf_dataset = dataset.to_tf_dataset(batch_size=batch_size, + columns=dataset.column_names, + collate_fn=data_collator, + shuffle=shuffle, + drop_remainder=True) + + for batch in tf_dataset: + batch = {k: v._numpy() for k, v in batch.items()} + yield batch + + +def write_train_metric(summary_writer, train_metrics, train_time, step): + summary_writer.scalar("train_time", train_time, step) + + train_metrics = alpa.util.get_metrics(train_metrics) + for key, vals in train_metrics.items(): + tag = f"train_{key}" + for i, val in enumerate(vals): + summary_writer.scalar(tag, val, step - len(vals) + i + 1) + + +def write_eval_metric(summary_writer, eval_metrics, step): + for metric_name, value in eval_metrics.items(): + summary_writer.scalar(f"eval_{metric_name}", value, step) + + +def create_learning_rate_fn( + train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float +) -> Callable[[int], jnp.array]: + """Returns a linear warmup, linear_decay learning rate function.""" + steps_per_epoch = train_ds_size // train_batch_size + num_train_steps = steps_per_epoch * num_train_epochs + warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) + decay_fn = optax.linear_schedule( + init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + ) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) + return schedule_fn + + +def monkey_patch_remat(): + # Use monkey patch to add remat for all transformer layers. 
+ from transformers.models.opt.modeling_flax_opt import FlaxOPTDecoderLayer, FlaxOPTDecoderLayerCollection + from flax.linen.partitioning import remat + from flax.linen.module import wrap_method_once + import flax.linen as nn + + @wrap_method_once + def setup(self): + self.layers = [ + remat(FlaxOPTDecoderLayer, static_argnums=(2, 3, 4))( + self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + self.layerdrop = self.config.layerdrop + + def call( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + ): + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + init_cache, + output_attentions, + deterministic, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + outputs = [hidden_states, all_hidden_states, all_self_attns] + return outputs + + setattr(FlaxOPTDecoderLayerCollection, "setup", setup) + setattr(FlaxOPTDecoderLayerCollection, "__call__", call) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_clm", model_args, data_args, framework="flax") + + if ( + os.path.exists(training_args.output_dir) + and os.listdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty." + "Use --overwrite_output_dir to overcome." + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + # Setup logging, we only want one process per machine to log things on the screen. + logger.setLevel(logging.INFO) + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + + # Set the verbosity to info of the Transformers logger (on main process only): + logger.info(f"Training/evaluation parameters {training_args}") + + # Set seed before initializing model. 
+ set_seed(training_args.seed) + + # Handle the repository creation + if training_args.push_to_hub: + if training_args.hub_model_id is None: + repo_name = get_full_repo_name( + Path(training_args.output_dir).absolute().name, token=training_args.hub_token + ) + else: + repo_name = training_args.hub_model_id + repo = Repository(training_args.output_dir, clone_from=repo_name) + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if data_args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + cache_dir=model_args.cache_dir, + keep_in_memory=False, + use_auth_token=True if model_args.use_auth_token else None, + ) + + if "validation" not in dataset.keys(): + dataset["validation"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + ) + dataset["train"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=f"train[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + ) + else: + data_files = {} + dataset_args = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + dataset_args["keep_linebreaks"] = data_args.keep_linebreaks + dataset = load_dataset( + extension, + data_files=data_files, + cache_dir=model_args.cache_dir, + **dataset_args, + use_auth_token=True if model_args.use_auth_token else None, + ) + + if "validation" not in dataset.keys(): + dataset["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{data_args.validation_split_percentage}%]", + cache_dir=model_args.cache_dir, + **dataset_args, + use_auth_token=True if model_args.use_auth_token else None, + ) + dataset["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{data_args.validation_split_percentage}%:]", + cache_dir=model_args.cache_dir, + **dataset_args, + use_auth_token=True if model_args.use_auth_token else None, + ) + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. 
+ # if model_args.config_name: + # config = AutoConfig.from_pretrained( + # model_args.config_name, + # cache_dir=model_args.cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + # elif model_args.model_name_or_path: + # config = AutoConfig.from_pretrained( + # model_args.model_name_or_path, + # cache_dir=model_args.cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + # else: + # config = CONFIG_MAPPING[model_args.model_type]() + # logger.warning("You are instantiating a new config instance from scratch.") + # TODO: merge with the above + config = LLaMAConfig.load_config('test') + + if training_args.use_remat: + monkey_patch_remat() + + if model_args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + use_auth_token=True if model_args.use_auth_token else None, + ) + elif model_args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + #use_fast=model_args.use_fast_tokenizer, + use_auth_token=True if model_args.use_auth_token else None, + use_fast=False, + ) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + # if model_args.model_name_or_path: + # model = FlaxAutoModelForCausalLM.from_pretrained( + # model_args.model_name_or_path, + # config=config, + # seed=training_args.seed, + # dtype=getattr(jnp, model_args.dtype), + # use_auth_token=True if model_args.use_auth_token else None, + # ) + # #from transformers import FlaxOPTForCausalLM + # #config.num_hidden_layers = 2 + # #model = FlaxOPTForCausalLM( + # # config=config, + # # seed=training_args.seed, + # # dtype=getattr(jnp, model_args.dtype), + # #) + # else: + # model = FlaxAutoModelForCausalLM.from_config( + # config, + # seed=training_args.seed, + # dtype=getattr(jnp, model_args.dtype), + # ) + model = FlaxLLaMAForCausalLM(config, (4, 2048)) + + # Preprocessing the datasets. + # First we tokenize all the texts. + if training_args.do_train: + column_names = dataset["train"].column_names + else: + column_names = dataset["validation"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function + tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base") + + def tokenize_function(examples): + with CaptureLogger(tok_logger) as cl: + output = tokenizer(examples[text_column_name]) + # clm input could be much much longer than block_size + if "Token indices sequence length is longer than the" in cl.out: + tok_logger.warning( + "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits" + " before being passed to the model." 
+ ) + return output + + logger.info("***** Tokenize dataset *****") + tokenized_datasets = dataset.map( + tokenize_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + ) + + if data_args.block_size is None: + block_size = tokenizer.model_max_length + if block_size > config.max_position_embeddings: + logger.warning( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + "Picking 1024 instead. You can change that default value by passing --block_size xxx." + ) + block_size = 1024 + else: + if data_args.block_size > tokenizer.model_max_length: + logger.warning( + f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." + ) + block_size = min(data_args.block_size, tokenizer.model_max_length) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder + # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower + # to preprocess. + # + # To speed up this part, we use multiprocessing. 
See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map + + logger.info("***** Build dataset *****") + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + ) + + if training_args.do_train: + if "train" not in tokenized_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = lm_datasets["train"] + if data_args.max_train_samples is not None: + max_train_samples = min(len(train_dataset), data_args.max_train_samples) + train_dataset = train_dataset.select(range(max_train_samples)) + + if training_args.do_eval: + if "validation" not in tokenized_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = lm_datasets["validation"] + if data_args.max_eval_samples is not None: + max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) + eval_dataset = eval_dataset.select(range(max_eval_samples)) + + # Adjust batch size and num_micro_batches for small datasets + num_devices = alpa.get_global_num_devices() + train_min_batch_size = (num_devices // training_args.operator_parallel // + training_args.pipeline_parallel * training_args.num_micro_batches) + eval_num_micro_batches = training_args.num_micro_batches + eval_min_batch_size = (num_devices // training_args.operator_parallel // + training_args.pipeline_parallel * eval_num_micro_batches) + while len(eval_dataset) < eval_min_batch_size: + eval_num_micro_batches //= 2 + eval_min_batch_size = (num_devices // training_args.operator_parallel // + training_args.pipeline_parallel * eval_num_micro_batches) + + # Enable tensorboard only on the master node + has_tensorboard = is_tensorboard_available() + if has_tensorboard: + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run pip install tensorboard to enable." + ) + + # Initialize our training + rng = jax.random.PRNGKey(training_args.seed) + rng, dropout_rng = jax.random.split(rng) + + # Store some constant + num_epochs = int(training_args.num_train_epochs) + train_batch_size = int(training_args.per_device_train_batch_size) * num_devices + eval_batch_size = int(training_args.per_device_eval_batch_size) * num_devices + steps_per_epoch = len(train_dataset) // train_batch_size + total_train_steps = steps_per_epoch * num_epochs + + # Create learning rate schedule + linear_decay_lr_schedule_fn = create_learning_rate_fn( + len(train_dataset), + train_batch_size, + training_args.num_train_epochs, + training_args.warmup_steps, + training_args.learning_rate, + ) + + # We use Optax's "masking" functionality to not apply weight decay + # to bias and LayerNorm scale parameters. decay_mask_fn returns a + # mask boolean with the same structure as the parameters. + # The mask is True for parameters that should be decayed. + # Note that this mask is specifically adapted for FlaxGPT2. + # For other models, one should correct the layer norm parameter naming + # accordingly. 
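To make the masking concrete, a minimal sketch of the same flatten/unflatten idea on a toy parameter tree (the names ln_1 and attn are illustrative only, not the real LLaMA parameter layout):

    from flax import traverse_util

    toy_params = {"ln_1": {"scale": 1.0, "bias": 0.0},
                  "attn": {"kernel": 1.0, "bias": 0.0}}
    flat = traverse_util.flatten_dict(toy_params)
    flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("ln_1", "scale"))
                 for path in flat}
    mask = traverse_util.unflatten_dict(flat_mask)
    # mask == {"ln_1": {"scale": False, "bias": False},
    #          "attn": {"kernel": True, "bias": False}}
    # Only attn/kernel would receive weight decay under such a mask.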
+ def decay_mask_fn(params): + flat_params = traverse_util.flatten_dict(params) + flat_mask = { + path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")]) + for path in flat_params + } + return traverse_util.unflatten_dict(flat_mask) + + # create adam optimizer + if training_args.adafactor: + # We use the default parameters here to initialize adafactor, + # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74 + optimizer = optax.adafactor( + learning_rate=linear_decay_lr_schedule_fn, + ) + else: + optimizer = optax.chain( + optax.clip_by_global_norm(1.0), + optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn) + ) + + # Setup train state + if model_args.dtype == "float16": + use_master_copy = True + dynamic_scale = DynamicScale() + # Fix a bug in huggingface's implementation (https://github.com/huggingface/transformers/pull/18462) + alpa.global_config.flax_always_use_fp16_embedding = True + else: + use_master_copy = dynamic_scale = None + state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, + dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) + + def loss_fn(logits, labels): + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + loss = optax.softmax_cross_entropy( + shift_logits, + jax.nn.one_hot(shift_labels, logits.shape[-1])) + return loss.mean() + + # Define gradient update step fn + def train_step(state, batch): + + def compute_loss(params): + labels = batch.pop("labels") + logits = state.apply_fn(**batch, params=params, deterministic=True)[0] + loss = loss_fn(logits, labels) + return loss + + dynamic_scale = state.dynamic_scale + if dynamic_scale: + grad_fn = dynamic_scale.value_and_grad(compute_loss) + dynamic_scale, is_fin, loss, grads = grad_fn(state.params) + else: + grad_fn = alpa.value_and_grad(compute_loss) + loss, grads = grad_fn(state.params) + + new_state = state.apply_gradients(grads=grads) + + if dynamic_scale: + new_state = new_state.replace( + opt_state=jax.tree_map( + functools.partial(jnp.where, is_fin), + new_state.opt_state, state.opt_state), + params=jax.tree_map( + functools.partial(jnp.where, is_fin), + new_state.params, state.params), + master_copy=jax.tree_map( + functools.partial(jnp.where, is_fin), + new_state.master_copy, state.master_copy), + dynamic_scale=dynamic_scale) + + metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + + return new_state, metrics + + # Define eval fn + def eval_step(params, batch): + labels = batch.pop("labels") + logits = model(**batch, params=params, deterministic=True)[0] + loss = loss_fn(logits, labels) + + # summarize metrics + metrics = {"loss": loss} + return metrics + + # Create parallel version of the train and eval step + method = alpa.get_3d_parallel_method( + num_micro_batches=training_args.num_micro_batches, + data_parallel=-1, + operator_parallel=training_args.operator_parallel, + pipeline_parallel=training_args.pipeline_parallel) + + p_train_step = alpa.parallelize(train_step, + method=method, + donate_argnums=(0,)) + p_eval_step = alpa.parallelize(eval_step, + method=alpa.FollowParallel( + p_train_step, num_micro_batches=eval_num_micro_batches)) + + dump_debug_info_train_step = 
dump_debug_info_eval_step = True + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_epochs}") + logger.info(f" Batch size per device (w. accumulation) = {training_args.per_device_train_batch_size}") + logger.info(f" Global train batch size (w. parallel & distributed) = {train_batch_size}") + logger.info(f" Total optimization steps = {total_train_steps}") + + train_time = 0 + train_metrics = [] + epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0) + + step_ct = 0 + last_time = time.time() + + epochs.write("Initial compilation. This might take some minutes...") + + for epoch in epochs: + # ======================== Training ================================ + train_start = time.time() + + # Create sampling rng + rng, input_rng = jax.random.split(rng) + + # Generate an epoch by shuffling sampling indices from the train dataset + train_loader = data_loader(input_rng, train_dataset, train_batch_size, + train_min_batch_size, shuffle=True) + steps_per_epoch = len(train_dataset) // train_batch_size + # train + for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): + batch = next(train_loader) + batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * + batch["attention_mask"]) - 1 + state, train_metric = p_train_step(state, batch) + train_metrics.append(train_metric) + + cur_step = epoch * (len(train_dataset) // train_batch_size) + step + + if dump_debug_info_train_step: + dump_debug_info_train_step = False + executable = p_train_step.get_last_executable() + executable.sync() + executable.dump_debug_info("alpa_debug_info") + epochs.write(f"Initial compilation completed. " + f"Time elapsed: {time.time() - train_start:.2f} s") + + step_ct += 1 + if cur_step % training_args.logging_steps == 0 and cur_step > 0: + executable.sync() + latency = (time.time() - last_time) / step_ct + throughput_tokens = np.prod(batch["input_ids"].shape) / latency + throughput_tflops = alpa.util.compute_gpt_tflops( + batch_size=batch["input_ids"].shape[0], + seq_len=batch["input_ids"].shape[1], + num_layers=config.num_hidden_layers, + hidden_size=config.hidden_size, + vocab_size=config.vocab_size, + num_gpus=alpa.get_global_num_devices(), + latency=latency) + step_ct = 0 + + # Save metrics + train_time += time.time() - train_start + if has_tensorboard: + write_train_metric(summary_writer, train_metrics, train_time, cur_step) + + train_metric = jax.tree_map(np.mean, train_metric) + + epochs.write( + f"Step... 
{cur_step} | " + f"Loss: {train_metric['loss'].mean():.4f}, " + f"Learning Rate: {train_metric['learning_rate'].mean():.5f}, " + f"Throughput: {throughput_tokens:.2f} token/s, " + f"{throughput_tflops:.2f} TFLOP/s" + ) + + train_metrics = [] + last_time = time.time() + + if cur_step % training_args.eval_steps == 0 and cur_step > 0: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, + eval_min_batch_size) + eval_steps = max(len(eval_dataset) // eval_batch_size, 1) + for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): + # Model forward + batch = next(eval_loader) + batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * + batch["attention_mask"]) - 1 + metrics = p_eval_step(state.params, batch) + eval_metrics.append(metrics) + + if dump_debug_info_eval_step: + dump_debug_info_eval_step = False + executable = p_eval_step.get_last_executable() + executable.dump_debug_info("alpa_debug_info") + + # normalize eval metrics + eval_metrics = alpa.util.get_metrics(eval_metrics) + eval_metrics = jax.tree_map(jnp.mean, eval_metrics) + + try: + eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) + except OverflowError: + eval_metrics["perplexity"] = float("inf") + + # Print metrics and update progress bar + desc = ( + f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity:" + f" {eval_metrics['perplexity']})" + ) + epochs.write(desc) + + # Save metrics + if has_tensorboard: + write_eval_metric(summary_writer, eval_metrics, cur_step) + + if cur_step % training_args.save_steps == 0 and cur_step > 0: + # save checkpoint after each epoch and push checkpoint to the hub + epochs.write("\nSave checkpoint...") + alpa.prefetch(state.params) + params = alpa.util.map_to_nparray(state.params) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + if training_args.push_to_hub: + repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) + + # Eval after training + if training_args.do_eval: + eval_metrics = [] + eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, + eval_min_batch_size) + eval_steps = max(len(eval_dataset) // eval_batch_size, 1) + for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): + # Model forward + batch = next(eval_loader) + batch["position_ids"] = (batch["attention_mask"].cumsum(axis=1) * + batch["attention_mask"]) - 1 + metrics = p_eval_step(state.params, batch) + eval_metrics.append(metrics) + + # normalize eval metrics + eval_metrics = alpa.util.get_metrics(eval_metrics) + eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics) + + try: + eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) + except OverflowError: + eval_metrics["perplexity"] = float("inf") + + eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()} + path = os.path.join(training_args.output_dir, "eval_results.json") + with open(path, "w") as f: + json.dump(eval_metrics, f, indent=4, sort_keys=True) + + # Save the final model + epochs.write("\nSave the final model...") + alpa.prefetch(state.params) + params = alpa.util.map_to_nparray(state.params) + model.save_pretrained(training_args.output_dir, params=params) + tokenizer.save_pretrained(training_args.output_dir) + + +if __name__ == "__main__": + main() diff --git 
a/examples/opt_finetune/run_llama.sh b/examples/opt_finetune/run_llama.sh new file mode 100644 index 000000000..65187f12f --- /dev/null +++ b/examples/opt_finetune/run_llama.sh @@ -0,0 +1,21 @@ +export PYTHONPATH=$HOME/alpa-proj/EasyLM:$PYTHONPATH +python3 run_easylm_flax.py \ + --output_dir="./output" \ + --model_name_or_path="$HOME/alpa-proj/llama-7b" \ + --dataset_name="wikitext" \ + --dataset_config_name="wikitext-2-raw-v1" \ + --do_train --do_eval \ + --block_size="1024" \ + --per_device_train_batch_size="32" \ + --per_device_eval_batch_size="32" \ + --num_micro_batches 64 \ + --operator_parallel 1 \ + --pipeline_parallel 4 \ + --dtype="float16" \ + --learning_rate="5e-4" --warmup_steps="2000" \ + --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.0" \ + --overwrite_output_dir \ + --num_train_epochs="10" \ + --logging_steps="5" \ + --save_steps="40" \ + --eval_steps="25" From ad6000ee8101d9c498817569183eeef82fbeb9c1 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Sun, 30 Apr 2023 01:07:16 +0000 Subject: [PATCH 02/11] support manual sharding --- examples/opt_finetune/run_easylm_flax.py | 111 +++++++++++++---------- examples/opt_finetune/run_llama.sh | 4 +- 2 files changed, 65 insertions(+), 50 deletions(-) diff --git a/examples/opt_finetune/run_easylm_flax.py b/examples/opt_finetune/run_easylm_flax.py index 43e797ad3..d3c84b1cc 100644 --- a/examples/opt_finetune/run_easylm_flax.py +++ b/examples/opt_finetune/run_easylm_flax.py @@ -1,9 +1,10 @@ # TODO: # 1. Import Llama Model Definition(done); -# 2. Import Manual partition spec; +# 2. Import Manual partition spec(done); # 3. Import Fastchat dataset; # 4. Weight Conversion(done); # 5. Distributed load/store. +# 6. wandb support #!/usr/bin/env python # coding=utf-8 @@ -48,20 +49,21 @@ import alpa from alpa.model.model_util import DynamicScale, TrainState -from alpa import AutoShardingOption, AutoLayerOption, ManualStageOption +from alpa import ManualShardingOption import jax +from jax.experimental.pjit import PartitionSpec import jax.numpy as jnp import optax import transformers +from transformers.testing_utils import CaptureLogger +from transformers.utils import get_full_repo_name, send_example_telemetry import tensorflow as tf from flax import traverse_util +from optax import tree_map_params from huggingface_hub import Repository from transformers import ( - CONFIG_MAPPING, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, - AutoConfig, AutoTokenizer, - FlaxAutoModelForCausalLM, HfArgumentParser, is_tensorboard_available, set_seed, @@ -69,12 +71,9 @@ alpa.init(cluster="ray") -from transformers.testing_utils import CaptureLogger -from transformers.utils import get_full_repo_name, send_example_telemetry - tf.config.experimental.set_visible_devices([], 'GPU') -from EasyLM.EasyLM.models.llama.llama_model import ( +from EasyLM.models.llama.llama_model import ( LLaMAConfig, FlaxLLaMAForCausalLMModule, FlaxLLaMAForCausalLM ) @@ -381,6 +380,46 @@ def call( setattr(FlaxOPTDecoderLayerCollection, "__call__", call) +def llama_manual_sharding(num_layers, state: TrainState): + # TODO: when rebased to jax 0.4.6, use the tree_map_with_path + param_partition = { + 'transformer': { + 'wte': {'embedding': PartitionSpec("mp", None)}, + 'ln_f': {'kernel': PartitionSpec(None)}, + 'h': { + '%d' % (layer): { + 'attention': { + # TODO: check whether we need the transpose or not + 'wq': {'kernel': PartitionSpec(None, "mp")}, + 'wk': {'kernel': PartitionSpec(None, "mp")}, + 'wv': {'kernel': PartitionSpec(None, "mp")}, + 'wo': {'kernel': PartitionSpec("mp", None)}, + }, + 
'feed_forward': { + 'w1': {'kernel': PartitionSpec(None, "mp")}, + 'w2': {'kernel': PartitionSpec("mp", None)}, + 'w3': {'kernel': PartitionSpec(None, "mp")}, + }, + 'attention_norm': {'kernel': PartitionSpec(None)}, + 'ffn_norm': {'kernel': PartitionSpec(None)}, + } + for layer in range(num_layers)}, + }, + 'lm_head': {'kernel': PartitionSpec(None, "mp")}, + } + replicate = lambda x : jax.tree_util.tree_map(lambda _: PartitionSpec(None), x) + opt_state = tree_map_params(state.tx, lambda _, spec: spec, state.opt_state, + param_partition, transform_non_params=lambda _: PartitionSpec(None)) + manual_partition = TrainState(step=PartitionSpec(None), + params=param_partition, + master_copy=param_partition, + dynamic_scale=replicate(state.dynamic_scale), + tx=state.tx, + apply_fn=state.apply_fn, + opt_state=opt_state) + return manual_partition + + def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. @@ -514,22 +553,6 @@ def main(): # Distributed training: # The .from_pretrained methods guarantee that only one local process can concurrently # download model & vocab. - # if model_args.config_name: - # config = AutoConfig.from_pretrained( - # model_args.config_name, - # cache_dir=model_args.cache_dir, - # use_auth_token=True if model_args.use_auth_token else None, - # ) - # elif model_args.model_name_or_path: - # config = AutoConfig.from_pretrained( - # model_args.model_name_or_path, - # cache_dir=model_args.cache_dir, - # use_auth_token=True if model_args.use_auth_token else None, - # ) - # else: - # config = CONFIG_MAPPING[model_args.model_type]() - # logger.warning("You are instantiating a new config instance from scratch.") - # TODO: merge with the above config = LLaMAConfig.load_config('test') if training_args.use_remat: @@ -556,28 +579,9 @@ def main(): "You can do it from another script, save it, and load it from here, using --tokenizer_name." ) - # if model_args.model_name_or_path: - # model = FlaxAutoModelForCausalLM.from_pretrained( - # model_args.model_name_or_path, - # config=config, - # seed=training_args.seed, - # dtype=getattr(jnp, model_args.dtype), - # use_auth_token=True if model_args.use_auth_token else None, - # ) - # #from transformers import FlaxOPTForCausalLM - # #config.num_hidden_layers = 2 - # #model = FlaxOPTForCausalLM( - # # config=config, - # # seed=training_args.seed, - # # dtype=getattr(jnp, model_args.dtype), - # #) - # else: - # model = FlaxAutoModelForCausalLM.from_config( - # config, - # seed=training_args.seed, - # dtype=getattr(jnp, model_args.dtype), - # ) - model = FlaxLLaMAForCausalLM(config, (4, 2048)) + # TODO(yonghao): don't init weight when loaded somewhere + dummy_input_shape = (4, config.max_sequence_length) + model = FlaxLLaMAForCausalLM(config, dummy_input_shape) # Preprocessing the datasets. # First we tokenize all the texts. @@ -747,6 +751,11 @@ def decay_mask_fn(params): learning_rate=linear_decay_lr_schedule_fn, ) else: + # A tmp hack for llama finetune. Remove it either: + # 1) rebase to jax 0.4 and use tree_util's mask with path for partition spec; + # 2) optax fixes the issue of symbolic exec with decay mask fn. 
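+        # Note: with weight_decay == 0.0 the AdamW decay mask has no numerical
+        # effect, so dropping it here only sidesteps the tracing issue above
+        # without changing the update rule.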
+ if training_args.weight_decay == 0.0: + decay_mask_fn = None optimizer = optax.chain( optax.clip_by_global_norm(1.0), optax.adamw( @@ -769,6 +778,11 @@ def decay_mask_fn(params): state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) + # Manual partition spec + state_manual_sharding = llama_manual_sharding(config.num_hidden_layers, state) + ms_option = ManualShardingOption( + ("dp", "mp"), in_axis_resources=(state_manual_sharding, PartitionSpec("dp", None))) + def loss_fn(logits, labels): shift_logits = logits[..., :-1, :] shift_labels = labels[..., 1:] @@ -828,7 +842,8 @@ def eval_step(params, batch): num_micro_batches=training_args.num_micro_batches, data_parallel=-1, operator_parallel=training_args.operator_parallel, - pipeline_parallel=training_args.pipeline_parallel) + pipeline_parallel=training_args.pipeline_parallel, + manual_sharding_option=ms_option) p_train_step = alpa.parallelize(train_step, method=method, diff --git a/examples/opt_finetune/run_llama.sh b/examples/opt_finetune/run_llama.sh index 65187f12f..faea5fd92 100644 --- a/examples/opt_finetune/run_llama.sh +++ b/examples/opt_finetune/run_llama.sh @@ -9,8 +9,8 @@ python3 run_easylm_flax.py \ --per_device_train_batch_size="32" \ --per_device_eval_batch_size="32" \ --num_micro_batches 64 \ - --operator_parallel 1 \ - --pipeline_parallel 4 \ + --operator_parallel 2 \ + --pipeline_parallel 2 \ --dtype="float16" \ --learning_rate="5e-4" --warmup_steps="2000" \ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.0" \ From 07001d4d8e8ae05049be2fe93c4d4d62d98b681b Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Sun, 30 Apr 2023 01:15:24 +0000 Subject: [PATCH 03/11] move to a new folder --- examples/llama_finetune/hf_jax_conversion.py | 44 +++++++++++++++++++ .../run_easylm_flax.py | 0 .../run_llama.sh | 0 3 files changed, 44 insertions(+) create mode 100644 examples/llama_finetune/hf_jax_conversion.py rename examples/{opt_finetune => llama_finetune}/run_easylm_flax.py (100%) rename examples/{opt_finetune => llama_finetune}/run_llama.sh (100%) diff --git a/examples/llama_finetune/hf_jax_conversion.py b/examples/llama_finetune/hf_jax_conversion.py new file mode 100644 index 000000000..ec39e6c7a --- /dev/null +++ b/examples/llama_finetune/hf_jax_conversion.py @@ -0,0 +1,44 @@ +import transformers +import numpy as np +import jax +import jax.numpy as jnp + +def import_hf_model(model_name_or_path): + model = transformers.AutoModelForCausalLM.from_pretrained( + model_name_or_path, + ) + return model + +def hf_to_jax(hf_model): + state_dict = hf_model.state_dict() + jax_weights = { + 'transformer': { + 'wte': {'embedding': state_dict['model.embed_tokens.weight'].numpy()}, + 'ln_f': {'kernel': state_dict['model.norm.weight'].numpy()}, + 'h': { + '%d' % (layer): { + 'attention': { + # TODO: check whether we need the transpose or not + 'wq': {'kernel': state_dict['model.layers.%d.self_attn.q_proj.weight' % (layer)].numpy().transpose()}, + 'wk': {'kernel': state_dict['model.layers.%d.self_attn.k_proj.weight' % (layer)].numpy().transpose()}, + 'wv': {'kernel': state_dict['model.layers.%d.self_attn.v_proj.weight' % (layer)].numpy().transpose()}, + 'wo': {'kernel': state_dict['model.layers.%d.self_attn.o_proj.weight' % (layer)].numpy().transpose()}, + }, + 'feed_forward': { + 'w1': {'kernel': state_dict['model.layers.%d.mlp.gate_proj.weight' % (layer)].numpy().transpose()}, + 'w2': {'kernel': state_dict['model.layers.%d.mlp.down_proj.weight' % 
(layer)].numpy().transpose()}, + 'w3': {'kernel': state_dict['model.layers.%d.mlp.up_proj.weight' % (layer)].numpy().transpose()}, + }, + 'attention_norm': {'kernel': state_dict['model.layers.%d.input_layernorm.weight' % (layer)].numpy()}, + 'ffn_norm': {'kernel': state_dict['model.layers.%d.post_attention_layernorm.weight' % (layer)].numpy()}, + } + for layer in range(hf_model.config.num_hidden_layers)}, + }, + 'lm_head': {'kernel': state_dict["lm_head.weight"].numpy().transpose()}, + } + return jax_weights + +if __name__ == "__main__": + hf_model = import_hf_model("./llama-7b") + jax_params = hf_to_jax(hf_model) + # EasyLM uses fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True)) to store the param diff --git a/examples/opt_finetune/run_easylm_flax.py b/examples/llama_finetune/run_easylm_flax.py similarity index 100% rename from examples/opt_finetune/run_easylm_flax.py rename to examples/llama_finetune/run_easylm_flax.py diff --git a/examples/opt_finetune/run_llama.sh b/examples/llama_finetune/run_llama.sh similarity index 100% rename from examples/opt_finetune/run_llama.sh rename to examples/llama_finetune/run_llama.sh From 90655adeb86869edbc4ec7080df8af5ff11caf67 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Sun, 30 Apr 2023 23:41:54 +0000 Subject: [PATCH 04/11] add dataset from fastchat --- alpa/parallel_method.py | 6 +- examples/llama_finetune/hf_datasets.py | 139 ++++++++++++++++ examples/llama_finetune/run_easylm_flax.py | 181 ++++----------------- 3 files changed, 172 insertions(+), 154 deletions(-) create mode 100644 examples/llama_finetune/hf_datasets.py diff --git a/alpa/parallel_method.py b/alpa/parallel_method.py index c38a4f831..d46a841fe 100644 --- a/alpa/parallel_method.py +++ b/alpa/parallel_method.py @@ -277,16 +277,20 @@ def get_3d_parallel_method(num_micro_batches: int, assert num_mesh_devices % num_devices_per_host == 0 physical_mesh_shape = (num_mesh_devices // num_devices_per_host, num_devices_per_host) + if pipeline_parallel == num_devices: + manual_sharding_option = None # If no pipeline parallel, degenerate into shard parallel if pp == 1 and allow_degenerate_into_shard_parallel: return ShardParallel(num_micro_batches=num_micro_batches, auto_sharding_option=AutoShardingOption( + enable_auto_sharding=manual_sharding_option is None, prefer_reduce_scatter=True, force_batch_dim_to_mesh_dim=0), devices=get_global_physical_mesh( create_if_not_exist=True).get_logical_mesh( - [data_parallel, operator_parallel])) + [data_parallel, operator_parallel]), + manual_sharding_option=manual_sharding_option) # Return pipeshard parallel if manual_layer_num is not None: diff --git a/examples/llama_finetune/hf_datasets.py b/examples/llama_finetune/hf_datasets.py new file mode 100644 index 000000000..22179f512 --- /dev/null +++ b/examples/llama_finetune/hf_datasets.py @@ -0,0 +1,139 @@ +import json +from typing import Dict + +from datasets import Dataset +import numpy as np +import transformers +from transformers.trainer_pt_utils import LabelSmoother + +from fastchat.conversation import get_default_conv_template, SeparatorStyle + + +def preprocess( + sources, + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + IGNORE_TOKEN_ID = LabelSmoother.ignore_index + conv = get_default_conv_template("vicuna").copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = 
source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + input_ids = tokenizer( + conversations, + return_tensors="np", + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = np.copy(input_ids) + + assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int((target != tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + + cur_len += round_len + target[cur_len:] = IGNORE_TOKEN_ID + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=np.array(input_ids != tokenizer.pad_token_id), + ) + + +class LazySupervisedDataset: + """Dataset for supervised fine-tuning.""" + + def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): + super(LazySupervisedDataset, self).__init__() + print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.raw_data = raw_data + self.cached_data_dict = {} + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i): + if i in self.cached_data_dict: + return self.cached_data_dict[i] + + ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer) + ret = dict( + input_ids=ret["input_ids"][0], + labels=ret["labels"][0], + attention_mask=ret["attention_mask"][0], + ) + self.cached_data_dict[i] = ret + + return ret + + def iter(self): + def gen(): + for i in range(len(self)): + yield self[i] + return gen + + +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, data_path +) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + print("Loading data...") + raw_data = json.load(open(data_path, "r")) + + # Split train/test + perm = np.random.permutation(len(raw_data)) + split = int(len(perm) * 0.98) + train_indices = perm[:split] + eval_indices = perm[split:] + train_raw_data = [raw_data[i] for i in train_indices] + eval_raw_data = [raw_data[i] for i in eval_indices] + print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}") + + train_dataset = LazySupervisedDataset(train_raw_data, tokenizer=tokenizer) + eval_dataset = LazySupervisedDataset(eval_raw_data, tokenizer=tokenizer) + train_dataset = Dataset.from_generator(train_dataset.iter()) + eval_dataset = Dataset.from_generator(eval_dataset.iter()) + return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) diff --git a/examples/llama_finetune/run_easylm_flax.py b/examples/llama_finetune/run_easylm_flax.py index d3c84b1cc..d920c38f3 100644 --- a/examples/llama_finetune/run_easylm_flax.py +++ b/examples/llama_finetune/run_easylm_flax.py @@ -1,7 +1,7 @@ # TODO: # 1. Import Llama Model Definition(done); # 2. 
Import Manual partition spec(done); -# 3. Import Fastchat dataset; +# 3. Import Fastchat dataset(done); # 4. Weight Conversion(done); # 5. Distributed load/store. # 6. wandb support @@ -38,13 +38,12 @@ from dataclasses import asdict, dataclass, field from enum import Enum import functools -from itertools import chain from pathlib import Path from typing import Callable, Optional import datasets import numpy as np -from datasets import Dataset, load_dataset +from datasets import Dataset from tqdm import tqdm import alpa @@ -55,7 +54,6 @@ import jax.numpy as jnp import optax import transformers -from transformers.testing_utils import CaptureLogger from transformers.utils import get_full_repo_name, send_example_telemetry import tensorflow as tf from flax import traverse_util @@ -63,7 +61,6 @@ from huggingface_hub import Repository from transformers import ( FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, - AutoTokenizer, HfArgumentParser, is_tensorboard_available, set_seed, @@ -74,9 +71,11 @@ tf.config.experimental.set_visible_devices([], 'GPU') from EasyLM.models.llama.llama_model import ( - LLaMAConfig, FlaxLLaMAForCausalLMModule, FlaxLLaMAForCausalLM + LLaMAConfig, FlaxLLaMAForCausalLM ) +from hf_datasets import make_supervised_data_module + logger = logging.getLogger(__name__) MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys()) @@ -85,6 +84,7 @@ @dataclass class TrainingArguments: + """A subset of Huggingface's training arguments""" output_dir: str = field( metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, ) @@ -475,76 +475,6 @@ def main(): repo_name = training_args.hub_model_id repo = Repository(training_args.output_dir, clone_from=repo_name) - # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) - # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ - # (the dataset will be downloaded automatically from the datasets Hub). - # - # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called - # 'text' is found. You can easily tweak this behavior (see below). - # - # In distributed training, the load_dataset function guarantees that only one local process can concurrently - # download the dataset. - if data_args.dataset_name is not None: - # Downloading and loading a dataset from the hub. 
- dataset = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - cache_dir=model_args.cache_dir, - keep_in_memory=False, - use_auth_token=True if model_args.use_auth_token else None, - ) - - if "validation" not in dataset.keys(): - dataset["validation"] = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - split=f"train[:{data_args.validation_split_percentage}%]", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - dataset["train"] = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - split=f"train[{data_args.validation_split_percentage}%:]", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - else: - data_files = {} - dataset_args = {} - if data_args.train_file is not None: - data_files["train"] = data_args.train_file - if data_args.validation_file is not None: - data_files["validation"] = data_args.validation_file - extension = data_args.train_file.split(".")[-1] - if extension == "txt": - extension = "text" - dataset_args["keep_linebreaks"] = data_args.keep_linebreaks - dataset = load_dataset( - extension, - data_files=data_files, - cache_dir=model_args.cache_dir, - **dataset_args, - use_auth_token=True if model_args.use_auth_token else None, - ) - - if "validation" not in dataset.keys(): - dataset["validation"] = load_dataset( - extension, - data_files=data_files, - split=f"train[:{data_args.validation_split_percentage}%]", - cache_dir=model_args.cache_dir, - **dataset_args, - use_auth_token=True if model_args.use_auth_token else None, - ) - dataset["train"] = load_dataset( - extension, - data_files=data_files, - split=f"train[{data_args.validation_split_percentage}%:]", - cache_dir=model_args.cache_dir, - **dataset_args, - use_auth_token=True if model_args.use_auth_token else None, - ) # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html. @@ -558,61 +488,29 @@ def main(): if training_args.use_remat: monkey_patch_remat() - if model_args.tokenizer_name: - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name, - cache_dir=model_args.cache_dir, - use_fast=model_args.use_fast_tokenizer, - use_auth_token=True if model_args.use_auth_token else None, - ) - elif model_args.model_name_or_path: - tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - #use_fast=model_args.use_fast_tokenizer, - use_auth_token=True if model_args.use_auth_token else None, - use_fast=False, - ) - else: - raise ValueError( - "You are instantiating a new tokenizer from scratch. This is not supported by this script." - "You can do it from another script, save it, and load it from here, using --tokenizer_name." - ) + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + model_max_length=config.max_sequence_length, + padding_side="right", + use_fast=False, + ) + tokenizer.pad_token = tokenizer.unk_token # TODO(yonghao): don't init weight when loaded somewhere dummy_input_shape = (4, config.max_sequence_length) model = FlaxLLaMAForCausalLM(config, dummy_input_shape) - # Preprocessing the datasets. - # First we tokenize all the texts. 
- if training_args.do_train: - column_names = dataset["train"].column_names - else: - column_names = dataset["validation"].column_names - text_column_name = "text" if "text" in column_names else column_names[0] - - # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function - tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base") - - def tokenize_function(examples): - with CaptureLogger(tok_logger) as cl: - output = tokenizer(examples[text_column_name]) - # clm input could be much much longer than block_size - if "Token indices sequence length is longer than the" in cl.out: - tok_logger.warning( - "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits" - " before being passed to the model." - ) - return output - - logger.info("***** Tokenize dataset *****") - tokenized_datasets = dataset.map( - tokenize_function, - batched=True, - num_proc=data_args.preprocessing_num_workers, - remove_columns=column_names, - load_from_cache_file=not data_args.overwrite_cache, - ) + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + data_module = make_supervised_data_module(tokenizer, data_args.dataset_name) + if data_args.block_size is None: block_size = tokenizer.model_max_length @@ -630,23 +528,6 @@ def tokenize_function(examples): ) block_size = min(data_args.block_size, tokenizer.model_max_length) - # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. - def group_texts(examples): - # Concatenate all texts. - concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} - total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can - # customize this part to your needs. - if total_length >= block_size: - total_length = (total_length // block_size) * block_size - # Split by chunks of max_len. - result = { - k: [t[i : i + block_size] for i in range(0, total_length, block_size)] - for k, t in concatenated_examples.items() - } - result["labels"] = result["input_ids"].copy() - return result - # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower # to preprocess. 
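The FastChat-style data module replaces the generic `load_dataset`/`group_texts` pipeline removed above. A minimal usage sketch of the new entry point (the paths and model directory below are hypothetical; the signature is the two-argument form used at this point in the series):

```python
import transformers
from hf_datasets import make_supervised_data_module

# Tokenizer set up the same way as in the training script: right padding,
# slow tokenizer, pad_token aliased to unk_token.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "/data/llama-7b", model_max_length=2048, padding_side="right", use_fast=False)
tokenizer.pad_token = tokenizer.unk_token

# Returns HuggingFace `datasets.Dataset` objects built from the lazy dataset's
# generator; each row carries input_ids, labels (instruction tokens masked with
# the ignore index), and attention_mask.
data_module = make_supervised_data_module(tokenizer, "/data/sharegpt.json")
train_dataset = data_module["train_dataset"]
eval_dataset = data_module["eval_dataset"]
```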
@@ -655,25 +536,19 @@ def group_texts(examples): # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map logger.info("***** Build dataset *****") - lm_datasets = tokenized_datasets.map( - group_texts, - batched=True, - num_proc=data_args.preprocessing_num_workers, - load_from_cache_file=not data_args.overwrite_cache, - ) if training_args.do_train: - if "train" not in tokenized_datasets: + if "train_dataset" not in data_module: raise ValueError("--do_train requires a train dataset") - train_dataset = lm_datasets["train"] + train_dataset = data_module["train_dataset"] if data_args.max_train_samples is not None: max_train_samples = min(len(train_dataset), data_args.max_train_samples) train_dataset = train_dataset.select(range(max_train_samples)) if training_args.do_eval: - if "validation" not in tokenized_datasets: + if "eval_dataset" not in data_module: raise ValueError("--do_eval requires a validation dataset") - eval_dataset = lm_datasets["validation"] + eval_dataset = data_module["eval_dataset"] if data_args.max_eval_samples is not None: max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) eval_dataset = eval_dataset.select(range(max_eval_samples)) From 7a9c5fa61471975ef218fae94c1d31f1b03bfb4d Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Mon, 1 May 2023 07:07:53 +0400 Subject: [PATCH 05/11] final workable combination --- alpa/util.py | 16 ++-- examples/llama_finetune/hf_jax_conversion.py | 5 +- examples/llama_finetune/run_easylm_flax.py | 98 +++++++------------- examples/llama_finetune/run_llama.sh | 29 +++--- 4 files changed, 55 insertions(+), 93 deletions(-) diff --git a/alpa/util.py b/alpa/util.py index 29649e565..0a1222a9e 100644 --- a/alpa/util.py +++ b/alpa/util.py @@ -1663,20 +1663,22 @@ def compute_gpt_tflops(batch_size, num_gpus, latency, backward=True, - checkpoint_activations=False): + checkpoint_activations=False, + intermediate_size=None): """ Compute the Tera Flop Operations (TFLOP) per second per GPU for GPT-like models. """ - factor = 24 + factor = 2 if backward: - factor += 48 + factor += 4 if checkpoint_activations: - factor += 24 + factor += 2 + if intermediate_size is None: + intermediate_size = hidden_size * 4 - total_flop = (factor * batch_size * seq_len * - (hidden_size**2) * num_layers * (1 + seq_len / - (6 * hidden_size)) + + total_flop = ((factor * num_layers * batch_size * seq_len * hidden_size * + (4 * hidden_size + 2 * intermediate_size + 2 * seq_len)) + 6 * batch_size * seq_len * hidden_size * vocab_size) # Note: The above formula does not count the first embedding table lookup # because it is a sparse operation. 
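The revised estimate drops the fused 24/48/24 constants and counts multiply-adds explicitly, so the MLP width no longer has to equal `4 * hidden_size` (LLaMA-7B uses `intermediate_size=11008` at `hidden_size=4096`). A standalone sketch of the same formula, with hypothetical shapes, useful for sanity-checking the reported TFLOP/s:

```python
def gpt_total_flop(batch_size, seq_len, num_layers, hidden_size, vocab_size,
                   intermediate_size=None, backward=True,
                   checkpoint_activations=False):
    # 2 FLOPs per multiply-add in the forward pass; the backward pass costs
    # roughly 2x the forward, and activation checkpointing adds one extra forward.
    factor = 2
    if backward:
        factor += 4
    if checkpoint_activations:
        factor += 2
    if intermediate_size is None:
        intermediate_size = hidden_size * 4
    # Per layer and token: 4*h*h multiply-adds for the attention projections,
    # 2*h*intermediate for the MLP, 2*s*h for the attention score/value matmuls;
    # plus the forward+backward of the final vocabulary projection outside the loop.
    return (factor * num_layers * batch_size * seq_len * hidden_size *
            (4 * hidden_size + 2 * intermediate_size + 2 * seq_len) +
            6 * batch_size * seq_len * hidden_size * vocab_size)

# Hypothetical LLaMA-7B-like shapes:
flop = gpt_total_flop(batch_size=32, seq_len=2048, num_layers=32,
                      hidden_size=4096, vocab_size=32000, intermediate_size=11008)
print(flop / 1e12, "TFLOP per step")
```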
diff --git a/examples/llama_finetune/hf_jax_conversion.py b/examples/llama_finetune/hf_jax_conversion.py index ec39e6c7a..d1f5daf4c 100644 --- a/examples/llama_finetune/hf_jax_conversion.py +++ b/examples/llama_finetune/hf_jax_conversion.py @@ -1,7 +1,4 @@ import transformers -import numpy as np -import jax -import jax.numpy as jnp def import_hf_model(model_name_or_path): model = transformers.AutoModelForCausalLM.from_pretrained( @@ -9,7 +6,7 @@ def import_hf_model(model_name_or_path): ) return model -def hf_to_jax(hf_model): +def hf_to_jax_weight(hf_model): state_dict = hf_model.state_dict() jax_weights = { 'transformer': { diff --git a/examples/llama_finetune/run_easylm_flax.py b/examples/llama_finetune/run_easylm_flax.py index d920c38f3..9c9c11409 100644 --- a/examples/llama_finetune/run_easylm_flax.py +++ b/examples/llama_finetune/run_easylm_flax.py @@ -75,6 +75,7 @@ ) from hf_datasets import make_supervised_data_module +from hf_jax_conversion import hf_to_jax_weight logger = logging.getLogger(__name__) @@ -116,7 +117,7 @@ class TrainingArguments: adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}) adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."}) num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."}) - warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) + warmup_ratio: float = field(default=0.0, metadata={"help": "Linear warmup over a ratio of overall steps."}) logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."}) save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."}) eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."}) @@ -315,71 +316,20 @@ def write_eval_metric(summary_writer, eval_metrics, step): def create_learning_rate_fn( - train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float + train_ds_size: int, train_batch_size: int, num_train_epochs: int, warmup_ratio: float, learning_rate: float ) -> Callable[[int], jnp.array]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs + num_warmup_steps = int(num_train_steps * warmup_ratio) warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) - decay_fn = optax.linear_schedule( - init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps + decay_fn = optax.cosine_decay_schedule( + init_value=learning_rate, decay_steps=num_train_steps - num_warmup_steps ) schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]) return schedule_fn -def monkey_patch_remat(): - # Use monkey patch to add remat for all transformer layers. 
- from transformers.models.opt.modeling_flax_opt import FlaxOPTDecoderLayer, FlaxOPTDecoderLayerCollection - from flax.linen.partitioning import remat - from flax.linen.module import wrap_method_once - import flax.linen as nn - - @wrap_method_once - def setup(self): - self.layers = [ - remat(FlaxOPTDecoderLayer, static_argnums=(2, 3, 4))( - self.config, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] - self.layerdrop = self.config.layerdrop - - def call( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - init_cache: bool = False, - output_attentions: bool = False, - output_hidden_states: bool = False, - ): - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - layer_outputs = decoder_layer( - hidden_states, - attention_mask, - init_cache, - output_attentions, - deterministic, - ) - - hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) - - outputs = [hidden_states, all_hidden_states, all_self_attns] - return outputs - - setattr(FlaxOPTDecoderLayerCollection, "setup", setup) - setattr(FlaxOPTDecoderLayerCollection, "__call__", call) - - def llama_manual_sharding(num_layers, state: TrainState): # TODO: when rebased to jax 0.4.6, use the tree_map_with_path param_partition = { @@ -483,10 +433,15 @@ def main(): # Distributed training: # The .from_pretrained methods guarantee that only one local process can concurrently # download model & vocab. - config = LLaMAConfig.load_config('test') - - if training_args.use_remat: - monkey_patch_remat() + config = LLaMAConfig.load_config('7b') + if model_args.dtype == "float16": + dtype = jnp.float16 + elif model_args.dtype == "float32": + dtype = jnp.float32 + elif model_args.dtype == "bfloat16": + dtype = jnp.bfloat16 + else: + raise ValueError(f"{model_args.dtype} unsupported") tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, @@ -498,7 +453,13 @@ def main(): # TODO(yonghao): don't init weight when loaded somewhere dummy_input_shape = (4, config.max_sequence_length) - model = FlaxLLaMAForCausalLM(config, dummy_input_shape) + # Monkey patch the model's init to init_dummy + model = FlaxLLaMAForCausalLM(config, dummy_input_shape, dtype=dtype) + hf_model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + ) + params = hf_to_jax_weight(hf_model) + del hf_model # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ @@ -595,11 +556,11 @@ def main(): total_train_steps = steps_per_epoch * num_epochs # Create learning rate schedule - linear_decay_lr_schedule_fn = create_learning_rate_fn( + cosine_decay_lr_schedule_fn = create_learning_rate_fn( len(train_dataset), train_batch_size, training_args.num_train_epochs, - training_args.warmup_steps, + training_args.warmup_ratio, training_args.learning_rate, ) @@ -623,7 +584,7 @@ def decay_mask_fn(params): # We use the default parameters here to initialize adafactor, # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74 optimizer = optax.adafactor( - learning_rate=linear_decay_lr_schedule_fn, + 
learning_rate=cosine_decay_lr_schedule_fn, ) else: # A tmp hack for llama finetune. Remove it either: @@ -634,7 +595,7 @@ def decay_mask_fn(params): optimizer = optax.chain( optax.clip_by_global_norm(1.0), optax.adamw( - learning_rate=linear_decay_lr_schedule_fn, + learning_rate=cosine_decay_lr_schedule_fn, b1=training_args.adam_beta1, b2=training_args.adam_beta2, eps=training_args.adam_epsilon, @@ -698,7 +659,7 @@ def compute_loss(params): new_state.master_copy, state.master_copy), dynamic_scale=dynamic_scale) - metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} + metrics = {"loss": loss, "learning_rate": cosine_decay_lr_schedule_fn(state.step)} return new_state, metrics @@ -718,6 +679,7 @@ def eval_step(params, batch): data_parallel=-1, operator_parallel=training_args.operator_parallel, pipeline_parallel=training_args.pipeline_parallel, + manual_layer_num=config.num_hidden_layers, manual_sharding_option=ms_option) p_train_step = alpa.parallelize(train_step, @@ -786,7 +748,9 @@ def eval_step(params, batch): hidden_size=config.hidden_size, vocab_size=config.vocab_size, num_gpus=alpa.get_global_num_devices(), - latency=latency) + latency=latency, + checkpoint_activations=True, + intermediate_size=config.intermediate_size) step_ct = 0 # Save metrics diff --git a/examples/llama_finetune/run_llama.sh b/examples/llama_finetune/run_llama.sh index faea5fd92..af909b4a1 100644 --- a/examples/llama_finetune/run_llama.sh +++ b/examples/llama_finetune/run_llama.sh @@ -1,21 +1,20 @@ -export PYTHONPATH=$HOME/alpa-proj/EasyLM:$PYTHONPATH +export PYTHONPATH=$HOME/alpa/EasyLM:$PYTHONPATH python3 run_easylm_flax.py \ --output_dir="./output" \ - --model_name_or_path="$HOME/alpa-proj/llama-7b" \ - --dataset_name="wikitext" \ - --dataset_config_name="wikitext-2-raw-v1" \ - --do_train --do_eval \ + --model_name_or_path="/data/llama-7b" \ + --dataset_name="/data/sharegpt.json" \ + --do_train \ --block_size="1024" \ --per_device_train_batch_size="32" \ - --per_device_eval_batch_size="32" \ - --num_micro_batches 64 \ - --operator_parallel 2 \ - --pipeline_parallel 2 \ + --per_device_eval_batch_size="16" \ + --num_micro_batches 32 \ + --operator_parallel 1 \ + --pipeline_parallel 8 \ --dtype="float16" \ - --learning_rate="5e-4" --warmup_steps="2000" \ - --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.0" \ + --learning_rate="5e-4" --warmup_ratio="0.03" \ + --weight_decay="0.0" \ --overwrite_output_dir \ - --num_train_epochs="10" \ - --logging_steps="5" \ - --save_steps="40" \ - --eval_steps="25" + --num_train_epochs="3" \ + --logging_steps="1" \ + --save_steps="3000" \ + --eval_steps="1000" From 36c9f4b07f7c065e7e0e5fbe252a46c8000f648a Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Mon, 1 May 2023 07:48:10 +0400 Subject: [PATCH 06/11] add readme and minor fix script --- examples/llama_finetune/README.md | 61 ++++++++++++++++++++++ examples/llama_finetune/monkey_patch.py | 17 ++++++ examples/llama_finetune/run_easylm_flax.py | 15 ++++-- examples/llama_finetune/run_llama.sh | 2 +- 4 files changed, 89 insertions(+), 6 deletions(-) create mode 100644 examples/llama_finetune/README.md create mode 100644 examples/llama_finetune/monkey_patch.py diff --git a/examples/llama_finetune/README.md b/examples/llama_finetune/README.md new file mode 100644 index 000000000..2e75f2b74 --- /dev/null +++ b/examples/llama_finetune/README.md @@ -0,0 +1,61 @@ +This script needs some monkey-patches on the original EasyLM's model definition: + +##### Fix Import Errors + +EasyLM is based on jax 0.4, 
while this branch is tested on jax 0.3.22. Some import errors needs to be fixed: + +``` +--- a/EasyLM/jax_utils.py ++++ b/EasyLM/jax_utils.py +@@ -10,8 +10,8 @@ import dill + import flax + import jax + import jax.numpy as jnp +-from jax.sharding import PartitionSpec as PS +-from jax.sharding import Mesh ++from jax.experimental.pjit import PartitionSpec as PS ++from jax.interpreters.pxla import Mesh + from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint + from jax.experimental.pjit import pjit + from jax.interpreters import pxla +``` + +``` +--- a/EasyLM/models/llama/llama_model.py ++++ b/EasyLM/models/llama/llama_model.py +@@ -8,7 +8,7 @@ import numpy as np + import jax + import jax.numpy as jnp + from jax import lax +-from jax.sharding import PartitionSpec as PS ++from jax.experimental.pjit import PartitionSpec as PS + import flax.linen as nn + from flax.core.frozen_dict import FrozenDict, freeze, unfreeze + from flax.linen import combine_masks, make_causal_mask +``` + +##### Support mark pipeline boundary +We use manual pipeline boundary, though the auto one works in most cases. So we add a marker at the end of each layer. + +Will monkey patch it in the training script later. + +``` +--- a/EasyLM/models/llama/llama_model.py ++++ b/EasyLM/models/llama/llama_model.py +@@ -31,6 +31,7 @@ from mlxu import function_args_to_config, load_pickle, open_file + from EasyLM.jax_utils import ( + with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy + ) ++from alpa import mark_pipeline_boundary + + + LLAMA_STANDARD_CONFIGS = { +@@ -829,6 +830,7 @@ class FlaxLLaMABlockCollection(nn.Module): + output_attentions, + fcm_mask, + ) ++ mark_pipeline_boundary() + hidden_states = layer_outputs[0] + + if output_attentions: +``` diff --git a/examples/llama_finetune/monkey_patch.py b/examples/llama_finetune/monkey_patch.py new file mode 100644 index 000000000..420667a47 --- /dev/null +++ b/examples/llama_finetune/monkey_patch.py @@ -0,0 +1,17 @@ +from functools import partial + +import jax +import jax.numpy as jnp + +from EasyLM.models.llama.llama_model import FlaxLLaMAForCausalLMModule + + +def do_monkey_patch(): + # TODO: jax 0.3.22 does not support eval shape with static args well. Remove + # after rebasing to jax 0.4, use the model's _do_init=False then. 
+ def init_dummy(self, *args, **kwargs): + avals = jax.eval_shape(partial(self._backup_init, **kwargs), *args) + return jax.tree_util.tree_map(lambda x: jnp.full(x.shape, 1e-8, x.dtype), + avals) + FlaxLLaMAForCausalLMModule._backup_init = FlaxLLaMAForCausalLMModule.init + FlaxLLaMAForCausalLMModule.init = init_dummy \ No newline at end of file diff --git a/examples/llama_finetune/run_easylm_flax.py b/examples/llama_finetune/run_easylm_flax.py index 9c9c11409..c9f18fddc 100644 --- a/examples/llama_finetune/run_easylm_flax.py +++ b/examples/llama_finetune/run_easylm_flax.py @@ -76,6 +76,7 @@ from hf_datasets import make_supervised_data_module from hf_jax_conversion import hf_to_jax_weight +from monkey_patch import do_monkey_patch logger = logging.getLogger(__name__) @@ -442,6 +443,7 @@ def main(): dtype = jnp.bfloat16 else: raise ValueError(f"{model_args.dtype} unsupported") + config.gradient_checkpointing = training_args.use_remat tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, @@ -454,11 +456,12 @@ def main(): # TODO(yonghao): don't init weight when loaded somewhere dummy_input_shape = (4, config.max_sequence_length) # Monkey patch the model's init to init_dummy + do_monkey_patch() model = FlaxLLaMAForCausalLM(config, dummy_input_shape, dtype=dtype) hf_model = transformers.AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, ) - params = hf_to_jax_weight(hf_model) + loaded_params = hf_to_jax_weight(hf_model) del hf_model # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) @@ -521,7 +524,7 @@ def main(): eval_num_micro_batches = training_args.num_micro_batches eval_min_batch_size = (num_devices // training_args.operator_parallel // training_args.pipeline_parallel * eval_num_micro_batches) - while len(eval_dataset) < eval_min_batch_size: + while training_args.do_eval and (len(eval_dataset) < eval_min_batch_size): eval_num_micro_batches //= 2 eval_min_batch_size = (num_devices // training_args.operator_parallel // training_args.pipeline_parallel * eval_num_micro_batches) @@ -611,7 +614,7 @@ def decay_mask_fn(params): alpa.global_config.flax_always_use_fp16_embedding = True else: use_master_copy = dynamic_scale = None - state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, + state = TrainState.create(apply_fn=model.__call__, params=loaded_params, tx=optimizer, dynamic_scale=dynamic_scale, use_master_copy=use_master_copy) # Manual partition spec @@ -632,7 +635,9 @@ def train_step(state, batch): def compute_loss(params): labels = batch.pop("labels") - logits = state.apply_fn(**batch, params=params, deterministic=True)[0] + # Currently we don't support non-deterministic training with remat, + # so train=False. This arg has no other impact. 
+ logits = state.apply_fn(**batch, params=params, train=False)[0] loss = loss_fn(logits, labels) return loss @@ -771,7 +776,7 @@ def eval_step(params, batch): train_metrics = [] last_time = time.time() - if cur_step % training_args.eval_steps == 0 and cur_step > 0: + if training_args.do_eval and cur_step % training_args.eval_steps == 0 and cur_step > 0: # ======================== Evaluating ============================== eval_metrics = [] eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, diff --git a/examples/llama_finetune/run_llama.sh b/examples/llama_finetune/run_llama.sh index af909b4a1..0fdf81135 100644 --- a/examples/llama_finetune/run_llama.sh +++ b/examples/llama_finetune/run_llama.sh @@ -1,4 +1,4 @@ -export PYTHONPATH=$HOME/alpa/EasyLM:$PYTHONPATH +export PYTHONPATH=$HOME/EasyLM:$PYTHONPATH python3 run_easylm_flax.py \ --output_dir="./output" \ --model_name_or_path="/data/llama-7b" \ From 896d3db0d254032d8b4ecca78fade8c89ea1a184 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Mon, 1 May 2023 07:51:53 +0400 Subject: [PATCH 07/11] format --- alpa/parallel_method.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/alpa/parallel_method.py b/alpa/parallel_method.py index d46a841fe..314b78706 100644 --- a/alpa/parallel_method.py +++ b/alpa/parallel_method.py @@ -282,15 +282,16 @@ def get_3d_parallel_method(num_micro_batches: int, # If no pipeline parallel, degenerate into shard parallel if pp == 1 and allow_degenerate_into_shard_parallel: - return ShardParallel(num_micro_batches=num_micro_batches, - auto_sharding_option=AutoShardingOption( - enable_auto_sharding=manual_sharding_option is None, - prefer_reduce_scatter=True, - force_batch_dim_to_mesh_dim=0), - devices=get_global_physical_mesh( - create_if_not_exist=True).get_logical_mesh( - [data_parallel, operator_parallel]), - manual_sharding_option=manual_sharding_option) + return ShardParallel( + num_micro_batches=num_micro_batches, + auto_sharding_option=AutoShardingOption( + enable_auto_sharding=manual_sharding_option is None, + prefer_reduce_scatter=True, + force_batch_dim_to_mesh_dim=0), + devices=get_global_physical_mesh( + create_if_not_exist=True).get_logical_mesh( + [data_parallel, operator_parallel]), + manual_sharding_option=manual_sharding_option) # Return pipeshard parallel if manual_layer_num is not None: From 5ccc35570185d5f2f274d68bf3c61d0f54b845d8 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Mon, 1 May 2023 08:15:10 +0400 Subject: [PATCH 08/11] minor fix --- examples/llama_finetune/run_easylm_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llama_finetune/run_easylm_flax.py b/examples/llama_finetune/run_easylm_flax.py index c9f18fddc..f39992529 100644 --- a/examples/llama_finetune/run_easylm_flax.py +++ b/examples/llama_finetune/run_easylm_flax.py @@ -443,7 +443,7 @@ def main(): dtype = jnp.bfloat16 else: raise ValueError(f"{model_args.dtype} unsupported") - config.gradient_checkpointing = training_args.use_remat + # TODO: set the correct remat policy. 
tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, From a4117b2dfc84b43e837055fde70791ae46378bb6 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Thu, 4 May 2023 03:16:00 +0400 Subject: [PATCH 09/11] add ignore index in dataset and fix loss --- examples/llama_finetune/hf_datasets.py | 43 +++++++------- examples/llama_finetune/hf_jax_conversion.py | 1 - examples/llama_finetune/run_easylm_flax.py | 60 +++++++++++++------- 3 files changed, 63 insertions(+), 41 deletions(-) diff --git a/examples/llama_finetune/hf_datasets.py b/examples/llama_finetune/hf_datasets.py index 22179f512..86d4aea4c 100644 --- a/examples/llama_finetune/hf_datasets.py +++ b/examples/llama_finetune/hf_datasets.py @@ -4,16 +4,12 @@ from datasets import Dataset import numpy as np import transformers -from transformers.trainer_pt_utils import LabelSmoother from fastchat.conversation import get_default_conv_template, SeparatorStyle -def preprocess( - sources, - tokenizer: transformers.PreTrainedTokenizer, -) -> Dict: - IGNORE_TOKEN_ID = LabelSmoother.ignore_index +def preprocess(sources, tokenizer: transformers.PreTrainedTokenizer, + ignore_token_id) -> Dict: conv = get_default_conv_template("vicuna").copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} @@ -50,7 +46,7 @@ def preprocess( rounds = conversation.split(conv.sep2) cur_len = 1 - target[:cur_len] = IGNORE_TOKEN_ID + target[:cur_len] = ignore_token_id for i, rou in enumerate(rounds): if rou == "": break @@ -62,18 +58,17 @@ def preprocess( round_len = len(tokenizer(rou).input_ids) instruction_len = len(tokenizer(parts[0]).input_ids) - 2 - target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + target[cur_len:cur_len + instruction_len] = ignore_token_id cur_len += round_len - target[cur_len:] = IGNORE_TOKEN_ID + target[cur_len:] = ignore_token_id if cur_len < tokenizer.model_max_length: if cur_len != total_len: - target[:] = IGNORE_TOKEN_ID + target[:] = ignore_token_id print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
- f" (ignored)" - ) + f" (ignored)") return dict( input_ids=input_ids, @@ -85,12 +80,14 @@ def preprocess( class LazySupervisedDataset: """Dataset for supervised fine-tuning.""" - def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): + def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, + ignore_token_id): super(LazySupervisedDataset, self).__init__() print("Formatting inputs...Skip in lazy mode") self.tokenizer = tokenizer self.raw_data = raw_data self.cached_data_dict = {} + self.ignore_token_id = ignore_token_id def __len__(self): return len(self.raw_data) @@ -99,7 +96,8 @@ def __getitem__(self, i): if i in self.cached_data_dict: return self.cached_data_dict[i] - ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer) + ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, + self.ignore_token_id) ret = dict( input_ids=ret["input_ids"][0], labels=ret["labels"][0], @@ -108,17 +106,18 @@ def __getitem__(self, i): self.cached_data_dict[i] = ret return ret - + def iter(self): + def gen(): for i in range(len(self)): yield self[i] + return gen -def make_supervised_data_module( - tokenizer: transformers.PreTrainedTokenizer, data_path -) -> Dict: +def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, + data_path, ignore_token_id) -> Dict: """Make dataset and collator for supervised fine-tuning.""" print("Loading data...") raw_data = json.load(open(data_path, "r")) @@ -132,8 +131,12 @@ def make_supervised_data_module( eval_raw_data = [raw_data[i] for i in eval_indices] print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}") - train_dataset = LazySupervisedDataset(train_raw_data, tokenizer=tokenizer) - eval_dataset = LazySupervisedDataset(eval_raw_data, tokenizer=tokenizer) + train_dataset = LazySupervisedDataset(train_raw_data, + tokenizer=tokenizer, + ignore_token_id=ignore_token_id) + eval_dataset = LazySupervisedDataset(eval_raw_data, + tokenizer=tokenizer, + ignore_token_id=ignore_token_id) train_dataset = Dataset.from_generator(train_dataset.iter()) eval_dataset = Dataset.from_generator(eval_dataset.iter()) return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) diff --git a/examples/llama_finetune/hf_jax_conversion.py b/examples/llama_finetune/hf_jax_conversion.py index d1f5daf4c..0464bb037 100644 --- a/examples/llama_finetune/hf_jax_conversion.py +++ b/examples/llama_finetune/hf_jax_conversion.py @@ -15,7 +15,6 @@ def hf_to_jax_weight(hf_model): 'h': { '%d' % (layer): { 'attention': { - # TODO: check whether we need the transpose or not 'wq': {'kernel': state_dict['model.layers.%d.self_attn.q_proj.weight' % (layer)].numpy().transpose()}, 'wk': {'kernel': state_dict['model.layers.%d.self_attn.k_proj.weight' % (layer)].numpy().transpose()}, 'wv': {'kernel': state_dict['model.layers.%d.self_attn.v_proj.weight' % (layer)].numpy().transpose()}, diff --git a/examples/llama_finetune/run_easylm_flax.py b/examples/llama_finetune/run_easylm_flax.py index f39992529..b821705a4 100644 --- a/examples/llama_finetune/run_easylm_flax.py +++ b/examples/llama_finetune/run_easylm_flax.py @@ -82,6 +82,7 @@ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) +IGNORE_TOKEN_ID = -100 @dataclass @@ -361,16 +362,39 @@ def llama_manual_sharding(num_layers, state: TrainState): replicate = lambda x : jax.tree_util.tree_map(lambda _: PartitionSpec(None), x) opt_state = tree_map_params(state.tx, lambda _, spec: spec, 
state.opt_state, param_partition, transform_non_params=lambda _: PartitionSpec(None)) - manual_partition = TrainState(step=PartitionSpec(None), - params=param_partition, - master_copy=param_partition, - dynamic_scale=replicate(state.dynamic_scale), - tx=state.tx, - apply_fn=state.apply_fn, - opt_state=opt_state) + manual_partition = TrainState( + step=PartitionSpec(None), + params=param_partition, + master_copy=param_partition if state.master_copy else None, + dynamic_scale=replicate(state.dynamic_scale), + tx=state.tx, + apply_fn=state.apply_fn, + opt_state=opt_state) return manual_partition +# TODO: smoothing factor +def loss_fn(logits, labels, ignore_indices): + # Shift logits + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + # Handle the ignore index: compute the valid first + valid = jnp.full(shift_labels.shape, True) + for ignore_index in ignore_indices: + new_valid = jnp.not_equal(shift_labels, ignore_index) + valid = jnp.logical_and(valid, new_valid) + valid_len = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10) + # OneHot and mask the ignore index. For ignore_index(-100), the whole line + # in the output would be 0. + one_hot_labels = jax.nn.one_hot(shift_labels, shift_logits.shape[-1]) + # Compute the softmax loss + log_p = jax.nn.log_softmax(shift_logits, axis=-1) + # (bs, seq_len, vocab) -> (bs, seq_len) + cross_entropy = jnp.sum(one_hot_labels * log_p, axis=-1) + loss = -jnp.mean(jnp.sum(cross_entropy, axis=-1) / valid_len) + return loss + + def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. @@ -452,6 +476,10 @@ def main(): use_fast=False, ) tokenizer.pad_token = tokenizer.unk_token + config.update(dict( + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + )) # TODO(yonghao): don't init weight when loaded somewhere dummy_input_shape = (4, config.max_sequence_length) @@ -462,7 +490,6 @@ def main(): model_args.model_name_or_path, ) loaded_params = hf_to_jax_weight(hf_model) - del hf_model # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ @@ -473,7 +500,7 @@ def main(): # # In distributed training, the load_dataset function guarantees that only one local process can concurrently # download the dataset. - data_module = make_supervised_data_module(tokenizer, data_args.dataset_name) + data_module = make_supervised_data_module(tokenizer, data_args.dataset_name, IGNORE_TOKEN_ID) if data_args.block_size is None: @@ -621,14 +648,7 @@ def decay_mask_fn(params): state_manual_sharding = llama_manual_sharding(config.num_hidden_layers, state) ms_option = ManualShardingOption( ("dp", "mp"), in_axis_resources=(state_manual_sharding, PartitionSpec("dp", None))) - - def loss_fn(logits, labels): - shift_logits = logits[..., :-1, :] - shift_labels = labels[..., 1:] - loss = optax.softmax_cross_entropy( - shift_logits, - jax.nn.one_hot(shift_labels, logits.shape[-1])) - return loss.mean() + ignore_ids = (IGNORE_TOKEN_ID, tokenizer.pad_token_id) # Define gradient update step fn def train_step(state, batch): @@ -638,7 +658,7 @@ def compute_loss(params): # Currently we don't support non-deterministic training with remat, # so train=False. This arg has no other impact. 
logits = state.apply_fn(**batch, params=params, train=False)[0] - loss = loss_fn(logits, labels) + loss = loss_fn(logits, labels, ignore_ids) return loss dynamic_scale = state.dynamic_scale @@ -672,7 +692,7 @@ def compute_loss(params): def eval_step(params, batch): labels = batch.pop("labels") logits = model(**batch, params=params, deterministic=True)[0] - loss = loss_fn(logits, labels) + loss = loss_fn(logits, labels, IGNORE_TOKEN_ID) # summarize metrics metrics = {"loss": loss} @@ -768,7 +788,7 @@ def eval_step(params, batch): epochs.write( f"Step... {cur_step} | " f"Loss: {train_metric['loss'].mean():.4f}, " - f"Learning Rate: {train_metric['learning_rate'].mean():.5f}, " + f"Learning Rate: {train_metric['learning_rate'].mean()}, " f"Throughput: {throughput_tokens:.2f} token/s, " f"{throughput_tflops:.2f} TFLOP/s" ) From d418e892ebbc54827e81e08efc7fab092814d463 Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Thu, 4 May 2023 10:35:55 +0400 Subject: [PATCH 10/11] add weight conversion --- examples/llama_finetune/hf_jax_conversion.py | 9 +++++++-- examples/llama_finetune/run_easylm_flax.py | 5 +++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/llama_finetune/hf_jax_conversion.py b/examples/llama_finetune/hf_jax_conversion.py index 0464bb037..b2aa63b89 100644 --- a/examples/llama_finetune/hf_jax_conversion.py +++ b/examples/llama_finetune/hf_jax_conversion.py @@ -8,6 +8,11 @@ def import_hf_model(model_name_or_path): def hf_to_jax_weight(hf_model): state_dict = hf_model.state_dict() + num_heads = hf_model.config.num_attention_heads + dim = hf_model.config.hidden_size + # inverse function of EasyLM's convert_easylm_to_hf.write_model.permute + def inv_permute(w): + return w.reshape(num_heads, 2, dim // num_heads // 2, dim).transpose(1, 2).reshape(dim, dim) jax_weights = { 'transformer': { 'wte': {'embedding': state_dict['model.embed_tokens.weight'].numpy()}, @@ -15,8 +20,8 @@ def hf_to_jax_weight(hf_model): 'h': { '%d' % (layer): { 'attention': { - 'wq': {'kernel': state_dict['model.layers.%d.self_attn.q_proj.weight' % (layer)].numpy().transpose()}, - 'wk': {'kernel': state_dict['model.layers.%d.self_attn.k_proj.weight' % (layer)].numpy().transpose()}, + 'wq': {'kernel': inv_permute(state_dict['model.layers.%d.self_attn.q_proj.weight' % (layer)]).numpy().transpose()}, + 'wk': {'kernel': inv_permute(state_dict['model.layers.%d.self_attn.k_proj.weight' % (layer)]).numpy().transpose()}, 'wv': {'kernel': state_dict['model.layers.%d.self_attn.v_proj.weight' % (layer)].numpy().transpose()}, 'wo': {'kernel': state_dict['model.layers.%d.self_attn.o_proj.weight' % (layer)].numpy().transpose()}, }, diff --git a/examples/llama_finetune/run_easylm_flax.py b/examples/llama_finetune/run_easylm_flax.py index b821705a4..0358c865b 100644 --- a/examples/llama_finetune/run_easylm_flax.py +++ b/examples/llama_finetune/run_easylm_flax.py @@ -383,6 +383,7 @@ def loss_fn(logits, labels, ignore_indices): for ignore_index in ignore_indices: new_valid = jnp.not_equal(shift_labels, ignore_index) valid = jnp.logical_and(valid, new_valid) + valid = jnp.asarray(valid, dtype=jnp.float32) valid_len = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10) # OneHot and mask the ignore index. For ignore_index(-100), the whole line # in the output would be 0. 
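To see the masking end-to-end, here is a self-contained sketch that mirrors the patched `loss_fn` on a tiny batch; the shapes, the random logits, and the use of 0 as a stand-in pad id are made up for illustration only:

```python
import jax
import jax.numpy as jnp

IGNORE_TOKEN_ID = -100

def masked_lm_loss(logits, labels, ignore_indices):
    # Shift so that position t predicts token t+1.
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]
    # Positions whose label is any ignore index contribute nothing to the loss.
    valid = jnp.full(shift_labels.shape, True)
    for ignore_index in ignore_indices:
        valid = jnp.logical_and(valid, jnp.not_equal(shift_labels, ignore_index))
    valid = jnp.asarray(valid, dtype=jnp.float32)
    valid_len = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
    # one_hot of a negative label (e.g. -100) is an all-zero row, so those
    # positions are already zeroed before the explicit mask is applied.
    one_hot_labels = jax.nn.one_hot(shift_labels, shift_logits.shape[-1])
    log_p = jax.nn.log_softmax(shift_logits, axis=-1)
    cross_entropy = jnp.sum(one_hot_labels * log_p, axis=-1)
    return -jnp.mean(jnp.sum(cross_entropy * valid, axis=-1) / valid_len)

logits = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 8))  # (batch, seq, vocab)
labels = jnp.array([[2, IGNORE_TOKEN_ID, 3, 0]])              # 0 plays the pad id here
print(masked_lm_loss(logits, labels, (IGNORE_TOKEN_ID, 0)))   # only the unmasked target counts
```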
@@ -391,7 +392,7 @@ def loss_fn(logits, labels, ignore_indices): log_p = jax.nn.log_softmax(shift_logits, axis=-1) # (bs, seq_len, vocab) -> (bs, seq_len) cross_entropy = jnp.sum(one_hot_labels * log_p, axis=-1) - loss = -jnp.mean(jnp.sum(cross_entropy, axis=-1) / valid_len) + loss = -jnp.mean(jnp.sum(cross_entropy * valid, axis=-1) / valid_len) return loss @@ -692,7 +693,7 @@ def compute_loss(params): def eval_step(params, batch): labels = batch.pop("labels") logits = model(**batch, params=params, deterministic=True)[0] - loss = loss_fn(logits, labels, IGNORE_TOKEN_ID) + loss = loss_fn(logits, labels, ignore_ids) # summarize metrics metrics = {"loss": loss} From 2752efc045185856b3d74a2370a0432cb61912eb Mon Sep 17 00:00:00 2001 From: ZYHowell Date: Sat, 13 May 2023 07:54:50 +0400 Subject: [PATCH 11/11] tmp --- examples/llama_finetune/monkey_patch.py | 3 +- examples/llama_finetune/test.ipynb | 1034 +++++++++++++++++++++++ 2 files changed, 1036 insertions(+), 1 deletion(-) create mode 100644 examples/llama_finetune/test.ipynb diff --git a/examples/llama_finetune/monkey_patch.py b/examples/llama_finetune/monkey_patch.py index 420667a47..c14cc783a 100644 --- a/examples/llama_finetune/monkey_patch.py +++ b/examples/llama_finetune/monkey_patch.py @@ -13,5 +13,6 @@ def init_dummy(self, *args, **kwargs): avals = jax.eval_shape(partial(self._backup_init, **kwargs), *args) return jax.tree_util.tree_map(lambda x: jnp.full(x.shape, 1e-8, x.dtype), avals) - FlaxLLaMAForCausalLMModule._backup_init = FlaxLLaMAForCausalLMModule.init + if not hasattr(FlaxLLaMAForCausalLMModule, "_backup_init"): + FlaxLLaMAForCausalLMModule._backup_init = FlaxLLaMAForCausalLMModule.init FlaxLLaMAForCausalLMModule.init = init_dummy \ No newline at end of file diff --git a/examples/llama_finetune/test.ipynb b/examples/llama_finetune/test.ipynb new file mode 100644 index 000000000..4ae796dc2 --- /dev/null +++ b/examples/llama_finetune/test.ipynb @@ -0,0 +1,1034 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/yonghao.zhuang/.local/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "INFO:__main__:Training/evaluation parameters TrainingArguments(output_dir='./output', overwrite_output_dir=False, do_train=True, do_eval=False, per_device_train_batch_size=8, per_device_eval_batch_size=8, num_micro_batches=1, operator_parallel=1, pipeline_parallel=1, use_remat=True, learning_rate=5e-05, weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, adafactor=False, num_train_epochs=3.0, warmup_ratio=0.0, logging_steps=500, save_steps=500, eval_steps=None, seed=42, push_to_hub=False, hub_model_id=None, hub_token=None)\n", + "Model config LLaMAConfig {\n", + " \"attn_pdrop\": 0.0,\n", + " \"bos_token_id\": 0,\n", + " \"embd_pdrop\": 0.0,\n", + " \"eos_token_id\": 1,\n", + " \"fcm_max_ratio\": 0.0,\n", + " \"fcm_min_ratio\": 0.0,\n", + " \"gradient_checkpointing\": \"nothing_saveable\",\n", + " \"hidden_size\": 4096,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 11008,\n", + " \"max_sequence_length\": 2048,\n", + " \"model_type\": \"llama\",\n", + " \"num_attention_heads\": 32,\n", + " \"num_hidden_layers\": 32,\n", + " \"resid_pdrop\": 0.0,\n", + " \"rms_norm_eps\": 1e-06,\n", + " \"tie_word_embeddings\": false,\n", + " \"transformers_version\": \"4.28.1\",\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 32000\n", + "}\n", + "\n", + "loading file tokenizer.model\n", + "loading file added_tokens.json\n", + "loading file special_tokens_map.json\n", + "loading file tokenizer_config.json\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[StreamExecutorGpuDevice(id=0, process_index=0), StreamExecutorGpuDevice(id=1, process_index=0), StreamExecutorGpuDevice(id=2, process_index=0), StreamExecutorGpuDevice(id=3, process_index=0), StreamExecutorGpuDevice(id=4, process_index=0), StreamExecutorGpuDevice(id=5, process_index=0), StreamExecutorGpuDevice(id=6, process_index=0), StreamExecutorGpuDevice(id=7, process_index=0)]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generate config GenerationConfig {\n", + " \"_from_model_config\": true,\n", + " \"bos_token_id\": 1,\n", + " \"eos_token_id\": 2,\n", + " \"transformers_version\": \"4.28.1\"\n", + "}\n", + "\n", + "2023-05-13 07:31:56.335021: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2156] Execution of replica 0 failed: INTERNAL: CustomCall failed: jaxlib/cuda/cuda_prng_kernels.cc:32: operation cudaGetLastError() failed: out of memory\n" + ] + }, + { + "ename": "XlaRuntimeError", + "evalue": "INTERNAL: CustomCall failed: jaxlib/cuda/cuda_prng_kernels.cc:32: operation cudaGetLastError() failed: out of memory", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mXlaRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[1], line 452\u001b[0m\n\u001b[1;32m 450\u001b[0m \u001b[39m# Monkey patch the model's init to init_dummy\u001b[39;00m\n\u001b[1;32m 451\u001b[0m do_monkey_patch()\n\u001b[0;32m--> 452\u001b[0m model \u001b[39m=\u001b[39m FlaxLLaMAForCausalLM(config, dummy_input_shape, dtype\u001b[39m=\u001b[39;49mdtype)\n\u001b[1;32m 453\u001b[0m hf_model \u001b[39m=\u001b[39m transformers\u001b[39m.\u001b[39mAutoModelForCausalLM\u001b[39m.\u001b[39mfrom_pretrained(\n\u001b[1;32m 454\u001b[0m 
model_args\u001b[39m.\u001b[39mmodel_name_or_path,\n\u001b[1;32m 455\u001b[0m torch_dtype\u001b[39m=\u001b[39mtorch_dtype\n\u001b[1;32m 456\u001b[0m )\n\u001b[1;32m 457\u001b[0m loaded_params \u001b[39m=\u001b[39m hf_to_jax_weight(hf_model)\n", + "File \u001b[0;32m~/alpa/EasyLM/EasyLM/models/llama/llama_model.py:644\u001b[0m, in \u001b[0;36mFlaxLLaMAPreTrainedModel.__init__\u001b[0;34m(self, config, input_shape, seed, dtype, _do_init, **kwargs)\u001b[0m\n\u001b[1;32m 634\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\n\u001b[1;32m 635\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 636\u001b[0m config: LLaMAConfig,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 641\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[1;32m 642\u001b[0m ):\n\u001b[1;32m 643\u001b[0m module \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodule_class(config\u001b[39m=\u001b[39mconfig, dtype\u001b[39m=\u001b[39mdtype, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[0;32m--> 644\u001b[0m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(config, module, input_shape\u001b[39m=\u001b[39;49minput_shape, seed\u001b[39m=\u001b[39;49mseed, dtype\u001b[39m=\u001b[39;49mdtype, _do_init\u001b[39m=\u001b[39;49m_do_init)\n", + "File \u001b[0;32m~/.local/lib/python3.8/site-packages/transformers/modeling_flax_utils.py:209\u001b[0m, in \u001b[0;36mFlaxPreTrainedModel.__init__\u001b[0;34m(self, config, module, input_shape, seed, dtype, _do_init)\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_is_initialized \u001b[39m=\u001b[39m _do_init\n\u001b[1;32m 207\u001b[0m \u001b[39mif\u001b[39;00m _do_init:\n\u001b[1;32m 208\u001b[0m \u001b[39m# randomly initialized parameters\u001b[39;00m\n\u001b[0;32m--> 209\u001b[0m random_params \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minit_weights(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mkey, input_shape)\n\u001b[1;32m 210\u001b[0m params_shape_tree \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39meval_shape(\u001b[39mlambda\u001b[39;00m params: params, random_params)\n\u001b[1;32m 211\u001b[0m \u001b[39melse\u001b[39;00m:\n", + "File \u001b[0;32m~/alpa/EasyLM/EasyLM/models/llama/llama_model.py:651\u001b[0m, in \u001b[0;36mFlaxLLaMAPreTrainedModel.init_weights\u001b[0;34m(self, rng, input_shape, params)\u001b[0m\n\u001b[1;32m 649\u001b[0m attention_mask \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39mones_like(input_ids)\n\u001b[1;32m 650\u001b[0m position_ids \u001b[39m=\u001b[39m jnp\u001b[39m.\u001b[39mbroadcast_to(jnp\u001b[39m.\u001b[39marange(jnp\u001b[39m.\u001b[39matleast_2d(input_ids)\u001b[39m.\u001b[39mshape[\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m]), input_shape)\n\u001b[0;32m--> 651\u001b[0m params_rng, dropout_rng \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39;49mrandom\u001b[39m.\u001b[39;49msplit(rng)\n\u001b[1;32m 652\u001b[0m rngs \u001b[39m=\u001b[39m {\u001b[39m\"\u001b[39m\u001b[39mparams\u001b[39m\u001b[39m\"\u001b[39m: params_rng, \u001b[39m\"\u001b[39m\u001b[39mdropout\u001b[39m\u001b[39m\"\u001b[39m: dropout_rng}\n\u001b[1;32m 654\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39madd_cross_attention:\n", + "File \u001b[0;32m~/alpa/alpa/third_party/jax/jax/_src/random.py:213\u001b[0m, in \u001b[0;36msplit\u001b[0;34m(key, num)\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Splits a PRNG key into `num` new keys by adding a 
leading axis.\u001b[39;00m\n\u001b[1;32m 203\u001b[0m \n\u001b[1;32m 204\u001b[0m \u001b[39mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[39m An array-like object of `num` new PRNG keys.\u001b[39;00m\n\u001b[1;32m 211\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 212\u001b[0m key, wrapped \u001b[39m=\u001b[39m _check_prng_key(key)\n\u001b[0;32m--> 213\u001b[0m \u001b[39mreturn\u001b[39;00m _return_prng_keys(wrapped, _split(key, num))\n", + "File \u001b[0;32m~/alpa/alpa/third_party/jax/jax/_src/random.py:199\u001b[0m, in \u001b[0;36m_split\u001b[0;34m(key, num)\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[39mif\u001b[39;00m key\u001b[39m.\u001b[39mndim:\n\u001b[1;32m 197\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39msplit accepts a single key, but was given a key array of\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 198\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mshape \u001b[39m\u001b[39m{\u001b[39;00mkey\u001b[39m.\u001b[39mshape\u001b[39m}\u001b[39;00m\u001b[39m != (). Use jax.vmap for batching.\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 199\u001b[0m \u001b[39mreturn\u001b[39;00m prng\u001b[39m.\u001b[39;49mrandom_split(key, count\u001b[39m=\u001b[39;49mnum)\n", + "File \u001b[0;32m~/alpa/alpa/third_party/jax/jax/_src/prng.py:624\u001b[0m, in \u001b[0;36mrandom_split\u001b[0;34m(keys, count)\u001b[0m\n\u001b[1;32m 623\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mrandom_split\u001b[39m(keys, count):\n\u001b[0;32m--> 624\u001b[0m \u001b[39mreturn\u001b[39;00m random_split_p\u001b[39m.\u001b[39;49mbind(keys, count\u001b[39m=\u001b[39;49mcount)\n", + "File \u001b[0;32m~/alpa/alpa/third_party/jax/jax/core.py:328\u001b[0m, in \u001b[0;36mPrimitive.bind\u001b[0;34m(self, *args, **params)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mbind\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mparams):\n\u001b[1;32m 326\u001b[0m \u001b[39massert\u001b[39;00m (\u001b[39mnot\u001b[39;00m config\u001b[39m.\u001b[39mjax_enable_checks \u001b[39mor\u001b[39;00m\n\u001b[1;32m 327\u001b[0m \u001b[39mall\u001b[39m(\u001b[39misinstance\u001b[39m(arg, Tracer) \u001b[39mor\u001b[39;00m valid_jaxtype(arg) \u001b[39mfor\u001b[39;00m arg \u001b[39min\u001b[39;00m args)), args\n\u001b[0;32m--> 328\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbind_with_trace(find_top_trace(args), args, params)\n", + "File \u001b[0;32m~/alpa/alpa/third_party/jax/jax/core.py:331\u001b[0m, in \u001b[0;36mPrimitive.bind_with_trace\u001b[0;34m(self, trace, args, params)\u001b[0m\n\u001b[1;32m 330\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mbind_with_trace\u001b[39m(\u001b[39mself\u001b[39m, trace, args, params):\n\u001b[0;32m--> 331\u001b[0m out \u001b[39m=\u001b[39m trace\u001b[39m.\u001b[39;49mprocess_primitive(\u001b[39mself\u001b[39;49m, \u001b[39mmap\u001b[39;49m(trace\u001b[39m.\u001b[39;49mfull_raise, args), params)\n\u001b[1;32m 332\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mmap\u001b[39m(full_lower, out) \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmultiple_results \u001b[39melse\u001b[39;00m full_lower(out)\n", + "File \u001b[0;32m~/alpa/alpa/third_party/jax/jax/core.py:698\u001b[0m, in \u001b[0;36mEvalTrace.process_primitive\u001b[0;34m(self, primitive, tracers, params)\u001b[0m\n\u001b[1;32m 697\u001b[0m \u001b[39mdef\u001b[39;00m 
\u001b[39mprocess_primitive\u001b[39m(\u001b[39mself\u001b[39m, primitive, tracers, params):\n\u001b[0;32m--> 698\u001b[0m \u001b[39mreturn\u001b[39;00m primitive\u001b[39m.\u001b[39;49mimpl(\u001b[39m*\u001b[39;49mtracers, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mparams)\n", + "File \u001b[0;32m~/alpa/alpa/third_party/jax/jax/_src/prng.py:636\u001b[0m, in \u001b[0;36mrandom_split_impl\u001b[0;34m(keys, count)\u001b[0m\n\u001b[1;32m 634\u001b[0m \u001b[39m@random_split_p\u001b[39m\u001b[39m.\u001b[39mdef_impl\n\u001b[1;32m 635\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mrandom_split_impl\u001b[39m(keys, \u001b[39m*\u001b[39m, count):\n\u001b[0;32m--> 636\u001b[0m base_arr \u001b[39m=\u001b[39m random_split_impl_base(\n\u001b[1;32m 637\u001b[0m keys\u001b[39m.\u001b[39;49mimpl, keys\u001b[39m.\u001b[39;49munsafe_raw_array(), keys\u001b[39m.\u001b[39;49mndim, count\u001b[39m=\u001b[39;49mcount)\n\u001b[1;32m 638\u001b[0m \u001b[39mreturn\u001b[39;00m PRNGKeyArray(keys\u001b[39m.\u001b[39mimpl, base_arr)\n", + "File \u001b[0;32m~/alpa/alpa/third_party/jax/jax/_src/prng.py:642\u001b[0m, in \u001b[0;36mrandom_split_impl_base\u001b[0;34m(impl, base_arr, keys_ndim, count)\u001b[0m\n\u001b[1;32m 640\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mrandom_split_impl_base\u001b[39m(impl, base_arr, keys_ndim, \u001b[39m*\u001b[39m, count):\n\u001b[1;32m 641\u001b[0m split \u001b[39m=\u001b[39m iterated_vmap_unary(keys_ndim, \u001b[39mlambda\u001b[39;00m k: impl\u001b[39m.\u001b[39msplit(k, count))\n\u001b[0;32m--> 642\u001b[0m \u001b[39mreturn\u001b[39;00m split(base_arr)\n", + "File \u001b[0;32m~/alpa/alpa/third_party/jax/jax/_src/prng.py:641\u001b[0m, in \u001b[0;36mrandom_split_impl_base..\u001b[0;34m(k)\u001b[0m\n\u001b[1;32m 640\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mrandom_split_impl_base\u001b[39m(impl, base_arr, keys_ndim, \u001b[39m*\u001b[39m, count):\n\u001b[0;32m--> 641\u001b[0m split \u001b[39m=\u001b[39m iterated_vmap_unary(keys_ndim, \u001b[39mlambda\u001b[39;00m k: impl\u001b[39m.\u001b[39;49msplit(k, count))\n\u001b[1;32m 642\u001b[0m \u001b[39mreturn\u001b[39;00m split(base_arr)\n", + "File \u001b[0;32m~/alpa/alpa/third_party/jax/jax/_src/prng.py:1033\u001b[0m, in \u001b[0;36mthreefry_split\u001b[0;34m(key, num)\u001b[0m\n\u001b[1;32m 1032\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mthreefry_split\u001b[39m(key: jnp\u001b[39m.\u001b[39mndarray, num: \u001b[39mint\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m jnp\u001b[39m.\u001b[39mndarray:\n\u001b[0;32m-> 1033\u001b[0m \u001b[39mreturn\u001b[39;00m _threefry_split(key, \u001b[39mint\u001b[39;49m(num))\n", + " \u001b[0;31m[... 
skipping hidden 6 frame]\u001b[0m\n", + "File \u001b[0;32m~/alpa/alpa/third_party/jax/jax/_src/dispatch.py:878\u001b[0m, in \u001b[0;36m_execute_compiled\u001b[0;34m(name, compiled, input_handler, output_buffer_counts, result_handler, has_unordered_effects, ordered_effects, kept_var_idx, has_host_callbacks, *args)\u001b[0m\n\u001b[1;32m 876\u001b[0m runtime_token \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 877\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 878\u001b[0m out_flat \u001b[39m=\u001b[39m compiled\u001b[39m.\u001b[39;49mexecute(in_flat)\n\u001b[1;32m 879\u001b[0m check_special(name, out_flat)\n\u001b[1;32m 880\u001b[0m out_bufs \u001b[39m=\u001b[39m unflatten(out_flat, output_buffer_counts)\n", + "\u001b[0;31mXlaRuntimeError\u001b[0m: INTERNAL: CustomCall failed: jaxlib/cuda/cuda_prng_kernels.cc:32: operation cudaGetLastError() failed: out of memory" + ] + } + ], + "source": [ + "\"\"\"\n", + "Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.\n", + "\n", + "Here is the full list of checkpoints on the hub that can be fine-tuned by this script:\n", + "https://huggingface.co/models?filter=text-generation\n", + "\"\"\"\n", + "# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.\n", + "\n", + "import json\n", + "import logging\n", + "import math\n", + "import os\n", + "import sys\n", + "import time\n", + "from dataclasses import asdict, dataclass, field\n", + "from enum import Enum\n", + "import functools\n", + "from pathlib import Path\n", + "from typing import Callable, Optional\n", + "\n", + "import datasets\n", + "import numpy as np\n", + "from datasets import Dataset\n", + "from tqdm import tqdm\n", + "\n", + "import alpa\n", + "from alpa.model.model_util import DynamicScale, TrainState\n", + "from alpa import ManualShardingOption\n", + "import jax\n", + "from jax.experimental.pjit import PartitionSpec\n", + "import jax.numpy as jnp\n", + "import optax\n", + "import transformers\n", + "from transformers.utils import get_full_repo_name, send_example_telemetry\n", + "import tensorflow as tf\n", + "from flax import traverse_util\n", + "from optax import tree_map_params\n", + "from huggingface_hub import Repository\n", + "from transformers import (\n", + " FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,\n", + " HfArgumentParser,\n", + " is_tensorboard_available,\n", + " set_seed,\n", + ")\n", + "\n", + "import torch\n", + "\n", + "# alpa.init(cluster=\"ray\")\n", + "\n", + "# tf.config.experimental.set_visible_devices([], 'GPU')\n", + "\n", + "from EasyLM.models.llama.llama_model import (\n", + " LLaMAConfig, FlaxLLaMAForCausalLM\n", + ")\n", + "\n", + "from hf_datasets import make_supervised_data_module\n", + "from hf_jax_conversion import hf_to_jax_weight\n", + "from monkey_patch import do_monkey_patch\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())\n", + "MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)\n", + "IGNORE_TOKEN_ID = -100\n", + "print(jax.devices())\n", + "\n", + "\n", + "@dataclass\n", + "class TrainingArguments:\n", + " \"\"\"A subset of Huggingface's training arguments\"\"\"\n", + " output_dir: str = field(\n", + " default=\"./output\",\n", + " metadata={\"help\": \"The output directory where the model predictions and checkpoints will be written.\"},\n", + " )\n", + " overwrite_output_dir: bool = field(\n", + 
" default=False,\n", + " metadata={\n", + " \"help\": (\n", + " \"Overwrite the content of the output directory. \"\n", + " \"Use this to continue training if output_dir points to a checkpoint directory.\"\n", + " )\n", + " },\n", + " )\n", + " do_train: bool = field(default=True, metadata={\"help\": \"Whether to run training.\"})\n", + " do_eval: bool = field(default=False, metadata={\"help\": \"Whether to run eval on the dev set.\"})\n", + " per_device_train_batch_size: int = field(\n", + " default=8, metadata={\"help\": \"Batch size per GPU/TPU core/CPU for training.\"}\n", + " )\n", + " per_device_eval_batch_size: int = field(\n", + " default=8, metadata={\"help\": \"Batch size per GPU/TPU core/CPU for evaluation.\"}\n", + " )\n", + " num_micro_batches: int = field(default=1, metadata={\"help\": \"The number of micro batches for gradient accumulation.\"})\n", + " operator_parallel: int = field(default=1, metadata={\"help\": \"The degree of operator model parallelism.\"})\n", + " pipeline_parallel: int = field(default=1, metadata={\"help\": \"The degree of pipeline model parallelism.\"})\n", + " use_remat: bool = field(default=True, metadata={\"help\": \"Whether or not to use gradient rematerilization/gradient checkpointing.\"})\n", + " learning_rate: float = field(default=5e-5, metadata={\"help\": \"The initial learning rate for AdamW.\"})\n", + " weight_decay: float = field(default=0.0, metadata={\"help\": \"Weight decay for AdamW if we apply some.\"})\n", + " adam_beta1: float = field(default=0.9, metadata={\"help\": \"Beta1 for AdamW optimizer\"})\n", + " adam_beta2: float = field(default=0.999, metadata={\"help\": \"Beta2 for AdamW optimizer\"})\n", + " adam_epsilon: float = field(default=1e-8, metadata={\"help\": \"Epsilon for AdamW optimizer.\"})\n", + " adafactor: bool = field(default=False, metadata={\"help\": \"Whether or not to replace AdamW by Adafactor.\"})\n", + " num_train_epochs: float = field(default=3.0, metadata={\"help\": \"Total number of training epochs to perform.\"})\n", + " warmup_ratio: float = field(default=0.0, metadata={\"help\": \"Linear warmup over a ratio of overall steps.\"})\n", + " logging_steps: int = field(default=500, metadata={\"help\": \"Log every X updates steps.\"})\n", + " save_steps: int = field(default=500, metadata={\"help\": \"Save checkpoint every X updates steps.\"})\n", + " eval_steps: int = field(default=None, metadata={\"help\": \"Run an evaluation every X steps.\"})\n", + " seed: int = field(default=42, metadata={\"help\": \"Random seed that will be set at the beginning of training.\"})\n", + " push_to_hub: bool = field(\n", + " default=False, metadata={\"help\": \"Whether or not to upload the trained model to the model hub after training.\"}\n", + " )\n", + " hub_model_id: str = field(\n", + " default=None, metadata={\"help\": \"The name of the repository to keep in sync with the local `output_dir`.\"}\n", + " )\n", + " hub_token: str = field(default=None, metadata={\"help\": \"The token to use to push to the Model Hub.\"})\n", + "\n", + " def __post_init__(self):\n", + " if self.output_dir is not None:\n", + " self.output_dir = os.path.expanduser(self.output_dir)\n", + "\n", + " def to_dict(self):\n", + " \"\"\"\n", + " Serializes this instance while replace `Enum` by their values (for JSON serialization support). 
It obfuscates\n", + " the token values by removing their value.\n", + " \"\"\"\n", + " d = asdict(self)\n", + " for k, v in d.items():\n", + " if isinstance(v, Enum):\n", + " d[k] = v.value\n", + " if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):\n", + " d[k] = [x.value for x in v]\n", + " if k.endswith(\"_token\"):\n", + " d[k] = f\"<{k.upper()}>\"\n", + " return d\n", + "\n", + "\n", + "@dataclass\n", + "class ModelArguments:\n", + " \"\"\"\n", + " Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.\n", + " \"\"\"\n", + "\n", + " model_name_or_path: Optional[str] = field(\n", + " default=\"/data/yonghao.zhuang/vicuna-7b-v1.2-b128l2\",\n", + " metadata={\n", + " \"help\": (\n", + " \"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch.\"\n", + " )\n", + " },\n", + " )\n", + " model_type: Optional[str] = field(\n", + " default=None,\n", + " metadata={\"help\": \"If training from scratch, pass a model type from the list: \" + \", \".join(MODEL_TYPES)},\n", + " )\n", + " config_name: Optional[str] = field(\n", + " default=None, metadata={\"help\": \"Pretrained config name or path if not the same as model_name\"}\n", + " )\n", + " tokenizer_name: Optional[str] = field(\n", + " default=None, metadata={\"help\": \"Pretrained tokenizer name or path if not the same as model_name\"}\n", + " )\n", + " cache_dir: Optional[str] = field(\n", + " default=None, metadata={\"help\": \"Where do you want to store the pretrained models downloaded from s3\"}\n", + " )\n", + " use_fast_tokenizer: bool = field(\n", + " default=True,\n", + " metadata={\"help\": \"Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.\"},\n", + " )\n", + " dtype: Optional[str] = field(\n", + " default=\"float16\",\n", + " metadata={\n", + " \"help\": (\n", + " \"Floating-point format in which the model weights should be initialized and trained. 
Choose one of\"\n", + " \" `[float32, float16, bfloat16]`.\"\n", + " )\n", + " },\n", + " )\n", + " use_auth_token: bool = field(\n", + " default=False,\n", + " metadata={\n", + " \"help\": (\n", + " \"Will use the token generated when running `transformers-cli login` (necessary to use this script \"\n", + " \"with private models).\"\n", + " )\n", + " },\n", + " )\n", + "\n", + "\n", + "@dataclass\n", + "class DataTrainingArguments:\n", + " \"\"\"\n", + " Arguments pertaining to what data we are going to input our model for training and eval.\n", + " \"\"\"\n", + "\n", + " dataset_name: Optional[str] = field(\n", + " default=\"/home/yonghao.zhuang/alpa/files/sharegpt/sharegpt_20230422_clean_lang_split_identity.json\", metadata={\"help\": \"The name of the dataset to use (via the datasets library).\"}\n", + " )\n", + " dataset_config_name: Optional[str] = field(\n", + " default=None, metadata={\"help\": \"The configuration name of the dataset to use (via the datasets library).\"}\n", + " )\n", + " train_file: Optional[str] = field(default=None, metadata={\"help\": \"The input training data file (a text file).\"})\n", + " validation_file: Optional[str] = field(\n", + " default=None,\n", + " metadata={\"help\": \"An optional input evaluation data file to evaluate the perplexity on (a text file).\"},\n", + " )\n", + " max_train_samples: Optional[int] = field(\n", + " default=None,\n", + " metadata={\n", + " \"help\": (\n", + " \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n", + " \"value if set.\"\n", + " )\n", + " },\n", + " )\n", + " max_eval_samples: Optional[int] = field(\n", + " default=None,\n", + " metadata={\n", + " \"help\": (\n", + " \"For debugging purposes or quicker training, truncate the number of evaluation examples to this \"\n", + " \"value if set.\"\n", + " )\n", + " },\n", + " )\n", + " overwrite_cache: bool = field(\n", + " default=False, metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n", + " )\n", + " validation_split_percentage: Optional[int] = field(\n", + " default=5,\n", + " metadata={\n", + " \"help\": \"The percentage of the train set used as validation set in case there's no validation split\"\n", + " },\n", + " )\n", + " block_size: Optional[int] = field(\n", + " default=1024,\n", + " metadata={\n", + " \"help\": (\n", + " \"Optional input sequence length after tokenization. \"\n", + " \"The training dataset will be truncated in block of this size for training. 
\"\n", + " \"Default to the model max input length for single sentence inputs (take into account special tokens).\"\n", + " )\n", + " },\n", + " )\n", + " overwrite_cache: bool = field(\n", + " default=False, metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n", + " )\n", + " preprocessing_num_workers: Optional[int] = field(\n", + " default=None,\n", + " metadata={\"help\": \"The number of processes to use for the preprocessing.\"},\n", + " )\n", + " keep_linebreaks: bool = field(\n", + " default=True, metadata={\"help\": \"Whether to keep line breaks when using TXT files or not.\"}\n", + " )\n", + "\n", + " def __post_init__(self):\n", + " if self.dataset_name is None and self.train_file is None and self.validation_file is None:\n", + " raise ValueError(\"Need either a dataset name or a training/validation file.\")\n", + " else:\n", + " if self.train_file is not None:\n", + " extension = self.train_file.split(\".\")[-1]\n", + " assert extension in [\"csv\", \"json\", \"txt\"], \"`train_file` should be a csv, a json or a txt file.\"\n", + " if self.validation_file is not None:\n", + " extension = self.validation_file.split(\".\")[-1]\n", + " assert extension in [\"csv\", \"json\", \"txt\"], \"`validation_file` should be a csv, a json or a txt file.\"\n", + "\n", + "\n", + "def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int,\n", + " min_batch_size: int, shuffle: bool = False):\n", + " \"\"\"\n", + " Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.\n", + " Shuffle batches if `shuffle` is `True`.\n", + " \"\"\"\n", + " if len(dataset) < batch_size:\n", + " assert len(dataset) >= min_batch_size\n", + " batch_size = len(dataset) // min_batch_size * min_batch_size\n", + "\n", + " data_collator = transformers.DefaultDataCollator(\"np\")\n", + " tf_dataset = dataset.to_tf_dataset(batch_size=batch_size,\n", + " columns=dataset.column_names,\n", + " collate_fn=data_collator,\n", + " shuffle=shuffle,\n", + " drop_remainder=True)\n", + "\n", + " for batch in tf_dataset:\n", + " batch = {k: v._numpy() for k, v in batch.items()}\n", + " yield batch\n", + "\n", + "\n", + "def write_train_metric(summary_writer, train_metrics, train_time, step):\n", + " summary_writer.scalar(\"train_time\", train_time, step)\n", + "\n", + " train_metrics = alpa.util.get_metrics(train_metrics)\n", + " for key, vals in train_metrics.items():\n", + " tag = f\"train_{key}\"\n", + " for i, val in enumerate(vals):\n", + " summary_writer.scalar(tag, val, step - len(vals) + i + 1)\n", + "\n", + "\n", + "def write_eval_metric(summary_writer, eval_metrics, step):\n", + " for metric_name, value in eval_metrics.items():\n", + " summary_writer.scalar(f\"eval_{metric_name}\", value, step)\n", + "\n", + "\n", + "def create_learning_rate_fn(\n", + " train_ds_size: int, train_batch_size: int, num_train_epochs: int, warmup_ratio: float, learning_rate: float\n", + ") -> Callable[[int], jnp.array]:\n", + " \"\"\"Returns a linear warmup, linear_decay learning rate function.\"\"\"\n", + " steps_per_epoch = train_ds_size // train_batch_size\n", + " num_train_steps = steps_per_epoch * num_train_epochs\n", + " num_warmup_steps = int(num_train_steps * warmup_ratio)\n", + " warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)\n", + " decay_fn = optax.cosine_decay_schedule(\n", + " init_value=learning_rate, decay_steps=num_train_steps - num_warmup_steps\n", + " )\n", + " schedule_fn = 
optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])\n", + " return schedule_fn\n", + "\n", + "\n", + "def llama_manual_sharding(num_layers, state: TrainState):\n", + " # TODO: when rebased to jax 0.4.6, use the tree_map_with_path\n", + " param_partition = {\n", + " 'transformer': {\n", + " 'wte': {'embedding': PartitionSpec(\"mp\", None)},\n", + " 'ln_f': {'kernel': PartitionSpec(None)},\n", + " 'h': {\n", + " '%d' % (layer): {\n", + " 'attention': {\n", + " 'wq': {'kernel': PartitionSpec(None, \"mp\")},\n", + " 'wk': {'kernel': PartitionSpec(None, \"mp\")},\n", + " 'wv': {'kernel': PartitionSpec(None, \"mp\")},\n", + " 'wo': {'kernel': PartitionSpec(\"mp\", None)},\n", + " },\n", + " 'feed_forward': {\n", + " 'w1': {'kernel': PartitionSpec(None, \"mp\")},\n", + " 'w2': {'kernel': PartitionSpec(\"mp\", None)},\n", + " 'w3': {'kernel': PartitionSpec(None, \"mp\")},\n", + " },\n", + " 'attention_norm': {'kernel': PartitionSpec(None)},\n", + " 'ffn_norm': {'kernel': PartitionSpec(None)},\n", + " }\n", + " for layer in range(num_layers)},\n", + " },\n", + " 'lm_head': {'kernel': PartitionSpec(None, \"mp\")},\n", + " }\n", + " replicate = lambda x : jax.tree_util.tree_map(lambda _: PartitionSpec(None), x)\n", + " opt_state = tree_map_params(state.tx, lambda _, spec: spec, state.opt_state,\n", + " param_partition, transform_non_params=lambda _: PartitionSpec(None))\n", + " manual_partition = TrainState(\n", + " step=PartitionSpec(None),\n", + " params=param_partition,\n", + " master_copy=param_partition if state.master_copy else None,\n", + " dynamic_scale=replicate(state.dynamic_scale),\n", + " tx=state.tx,\n", + " apply_fn=state.apply_fn,\n", + " opt_state=opt_state)\n", + " return manual_partition\n", + "\n", + "\n", + "# TODO: smoothing factor\n", + "def loss_fn(logits, labels, ignore_indices):\n", + " # Shift logits\n", + " shift_logits = logits[..., :-1, :]\n", + " shift_labels = labels[..., 1:]\n", + " # Handle the ignore index: compute the valid first\n", + " valid = jnp.full(shift_labels.shape, True)\n", + " for ignore_index in ignore_indices:\n", + " new_valid = jnp.not_equal(shift_labels, ignore_index)\n", + " valid = jnp.logical_and(valid, new_valid)\n", + " valid = jnp.asarray(valid, dtype=jnp.float32)\n", + " valid_len = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)\n", + " # OneHot and mask the ignore index. 
For ignore_index(-100), the whole line\n", + " # in the output would be 0.\n", + " one_hot_labels = jax.nn.one_hot(shift_labels, shift_logits.shape[-1])\n", + " # Compute the softmax loss\n", + " log_p = jax.nn.log_softmax(shift_logits, axis=-1)\n", + " # (bs, seq_len, vocab) -> (bs, seq_len)\n", + " cross_entropy = jnp.sum(one_hot_labels * log_p, axis=-1)\n", + " loss = -jnp.mean(jnp.sum(cross_entropy * valid, axis=-1) / valid_len)\n", + " return loss\n", + "\n", + "\n", + "# See all possible arguments in src/transformers/training_args.py\n", + "# or by passing the --help flag to this script.\n", + "# We now keep distinct sets of args, for a cleaner separation of concerns.\n", + "\n", + "model_args, data_args, training_args = ModelArguments(), DataTrainingArguments(), TrainingArguments()\n", + "\n", + "if (\n", + " os.path.exists(training_args.output_dir)\n", + " and os.listdir(training_args.output_dir)\n", + " and training_args.do_train\n", + " and not training_args.overwrite_output_dir\n", + "):\n", + " raise ValueError(\n", + " f\"Output directory ({training_args.output_dir}) already exists and is not empty.\"\n", + " \"Use --overwrite_output_dir to overcome.\"\n", + " )\n", + "\n", + "# Make one log on every process with the configuration for debugging.\n", + "logging.basicConfig(\n", + " format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n", + " datefmt=\"%m/%d/%Y %H:%M:%S\",\n", + " level=logging.INFO,\n", + ")\n", + "# Setup logging, we only want one process per machine to log things on the screen.\n", + "logger.setLevel(logging.INFO)\n", + "datasets.utils.logging.set_verbosity_warning()\n", + "transformers.utils.logging.set_verbosity_info()\n", + "\n", + "# Set the verbosity to info of the Transformers logger (on main process only):\n", + "logger.info(f\"Training/evaluation parameters {training_args}\")\n", + "\n", + "# Set seed before initializing model.\n", + "set_seed(training_args.seed)\n", + "\n", + "# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at\n", + "# https://huggingface.co/docs/datasets/loading_datasets.html.\n", + "\n", + "# Load pretrained model and tokenizer\n", + "\n", + "# Distributed training:\n", + "# The .from_pretrained methods guarantee that only one local process can concurrently\n", + "# download model & vocab.\n", + "config = LLaMAConfig.load_config('7b')\n", + "if model_args.dtype == \"float16\":\n", + " dtype = jnp.float16\n", + " torch_dtype = torch.float16\n", + "elif model_args.dtype == \"float32\":\n", + " dtype = jnp.float32\n", + " torch_dtype = torch.float32\n", + "elif model_args.dtype == \"bfloat16\":\n", + " dtype = jnp.bfloat16\n", + " torch_dtype = torch.bfloat16\n", + "else:\n", + " raise ValueError(f\"{model_args.dtype} unsupported\")\n", + "# TODO: set the correct remat policy.\n", + "\n", + "tokenizer = transformers.AutoTokenizer.from_pretrained(\n", + " model_args.model_name_or_path,\n", + " model_max_length=config.max_sequence_length,\n", + " padding_side=\"right\",\n", + " use_fast=False,\n", + " dtype=model_args.dtype\n", + ")\n", + "tokenizer.pad_token = tokenizer.unk_token\n", + "config.update(dict(\n", + " bos_token_id=tokenizer.bos_token_id,\n", + " eos_token_id=tokenizer.eos_token_id,\n", + "))\n", + "\n", + "# TODO(yonghao): don't init weight when loaded somewhere\n", + "dummy_input_shape = (4, config.max_sequence_length)\n", + "# Monkey patch the model's init to init_dummy\n", + "do_monkey_patch()\n", + "model = FlaxLLaMAForCausalLM(config, 
dummy_input_shape, dtype=dtype)\n", + "hf_model = transformers.AutoModelForCausalLM.from_pretrained(\n", + " model_args.model_name_or_path,\n", + " torch_dtype=torch_dtype\n", + ")\n", + "loaded_params = hf_to_jax_weight(hf_model)\n", + "\n", + "# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)\n", + "# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/\n", + "# (the dataset will be downloaded automatically from the datasets Hub).\n", + "#\n", + "# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called\n", + "# 'text' is found. You can easily tweak this behavior (see below).\n", + "#\n", + "# In distributed training, the load_dataset function guarantees that only one local process can concurrently\n", + "# download the dataset.\n", + "data_module = make_supervised_data_module(tokenizer, data_args.dataset_name, IGNORE_TOKEN_ID)\n", + "\n", + "\n", + "if data_args.block_size is None:\n", + " block_size = tokenizer.model_max_length\n", + " if block_size > config.max_position_embeddings:\n", + " logger.warning(\n", + " f\"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). \"\n", + " \"Picking 1024 instead. You can change that default value by passing --block_size xxx.\"\n", + " )\n", + " block_size = 1024\n", + "else:\n", + " if data_args.block_size > tokenizer.model_max_length:\n", + " logger.warning(\n", + " f\"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model\"\n", + " f\"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.\"\n", + " )\n", + " block_size = min(data_args.block_size, tokenizer.model_max_length)\n", + "\n", + "# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder\n", + "# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower\n", + "# to preprocess.\n", + "#\n", + "# To speed up this part, we use multiprocessing. 
See the documentation of the map method for more information:\n", + "# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map\n", + "\n", + "logger.info(\"***** Build dataset *****\")\n", + "\n", + "if training_args.do_train:\n", + " if \"train_dataset\" not in data_module:\n", + " raise ValueError(\"--do_train requires a train dataset\")\n", + " train_dataset = data_module[\"train_dataset\"]\n", + " if data_args.max_train_samples is not None:\n", + " max_train_samples = min(len(train_dataset), data_args.max_train_samples)\n", + " train_dataset = train_dataset.select(range(max_train_samples))\n", + "\n", + "if training_args.do_eval:\n", + " if \"eval_dataset\" not in data_module:\n", + " raise ValueError(\"--do_eval requires a validation dataset\")\n", + " eval_dataset = data_module[\"eval_dataset\"]\n", + " if data_args.max_eval_samples is not None:\n", + " max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)\n", + " eval_dataset = eval_dataset.select(range(max_eval_samples))\n", + "\n", + "# Adjust batch size and num_micro_batches for small datasets\n", + "# num_devices = alpa.get_global_num_devices()\n", + "num_devices = 8\n", + "train_min_batch_size = (num_devices // training_args.operator_parallel //\n", + " training_args.pipeline_parallel * training_args.num_micro_batches)\n", + "eval_num_micro_batches = training_args.num_micro_batches\n", + "eval_min_batch_size = (num_devices // training_args.operator_parallel //\n", + " training_args.pipeline_parallel * eval_num_micro_batches)\n", + "while training_args.do_eval and (len(eval_dataset) < eval_min_batch_size):\n", + " eval_num_micro_batches //= 2\n", + " eval_min_batch_size = (num_devices // training_args.operator_parallel //\n", + " training_args.pipeline_parallel * eval_num_micro_batches)\n", + "\n", + "# Initialize our training\n", + "rng = jax.random.PRNGKey(training_args.seed)\n", + "rng, dropout_rng = jax.random.split(rng)\n", + "\n", + "# Store some constant\n", + "num_epochs = int(training_args.num_train_epochs)\n", + "train_batch_size = int(training_args.per_device_train_batch_size) * num_devices\n", + "eval_batch_size = int(training_args.per_device_eval_batch_size) * num_devices\n", + "steps_per_epoch = len(train_dataset) // train_batch_size\n", + "total_train_steps = steps_per_epoch * num_epochs\n", + "\n", + "# Create learning rate schedule\n", + "cosine_decay_lr_schedule_fn = create_learning_rate_fn(\n", + " len(train_dataset),\n", + " train_batch_size,\n", + " training_args.num_train_epochs,\n", + " training_args.warmup_ratio,\n", + " training_args.learning_rate,\n", + ")\n", + "\n", + "# We use Optax's \"masking\" functionality to not apply weight decay\n", + "# to bias and LayerNorm scale parameters. 
decay_mask_fn returns a\n", + "# mask boolean with the same structure as the parameters.\n", + "# The mask is True for parameters that should be decayed.\n", + "# Note that this mask is specifically adapted for FlaxGPT2.\n", + "# For other models, one should correct the layer norm parameter naming\n", + "# accordingly.\n", + "def decay_mask_fn(params):\n", + " flat_params = traverse_util.flatten_dict(params)\n", + " flat_mask = {\n", + " path: (path[-1] != \"bias\" and path[-2:] not in [(\"ln_1\", \"scale\"), (\"ln_2\", \"scale\"), (\"ln_f\", \"scale\")])\n", + " for path in flat_params\n", + " }\n", + " return traverse_util.unflatten_dict(flat_mask)\n", + "\n", + "# create adam optimizer\n", + "if training_args.adafactor:\n", + " # We use the default parameters here to initialize adafactor,\n", + " # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74\n", + " optimizer = optax.adafactor(\n", + " learning_rate=cosine_decay_lr_schedule_fn,\n", + " )\n", + "else:\n", + " # A tmp hack for llama finetune. Remove it either:\n", + " # 1) rebase to jax 0.4 and use tree_util's mask with path for partition spec;\n", + " # 2) optax fixes the issue of symbolic exec with decay mask fn.\n", + " if training_args.weight_decay == 0.0:\n", + " decay_mask_fn = None\n", + " optimizer = optax.chain(\n", + " optax.clip_by_global_norm(1.0),\n", + " optax.adamw(\n", + " learning_rate=cosine_decay_lr_schedule_fn,\n", + " b1=training_args.adam_beta1,\n", + " b2=training_args.adam_beta2,\n", + " eps=training_args.adam_epsilon,\n", + " weight_decay=training_args.weight_decay,\n", + " mask=decay_mask_fn)\n", + " )\n", + "\n", + "# Setup train state\n", + "if model_args.dtype == \"float16\":\n", + " use_master_copy = True\n", + " dynamic_scale = DynamicScale()\n", + " # Fix a bug in huggingface's implementation (https://github.com/huggingface/transformers/pull/18462)\n", + " alpa.global_config.flax_always_use_fp16_embedding = True\n", + "else:\n", + " use_master_copy = dynamic_scale = None\n", + "# state = TrainState.create(apply_fn=model.__call__, params=loaded_params, tx=optimizer,\n", + "# dynamic_scale=dynamic_scale, use_master_copy=use_master_copy)\n", + "\n", + "# # Manual partition spec\n", + "# state_manual_sharding = llama_manual_sharding(config.num_hidden_layers, state)\n", + "# ms_option = ManualShardingOption(\n", + "# (\"dp\", \"mp\"), in_axis_resources=(state_manual_sharding, PartitionSpec(\"dp\", None)))\n", + "ignore_ids = (IGNORE_TOKEN_ID, tokenizer.pad_token_id)\n", + "\n", + "train_time = 0\n", + "train_metrics = []\n", + "\n", + "step_ct = 0\n", + "last_time = time.time()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rng, input_rng = jax.random.split(rng)\n", + "train_loader = data_loader(input_rng, train_dataset, 1,\n", + " 1, shuffle=False)\n", + "steps_per_epoch = len(train_dataset) // train_batch_size\n", + "\n", + "batch = next(train_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "position_ids = batch[\"attention_mask\"].cumsum(-1) - 1\n", + "# position_ids = np.select(batch[\"attention_mask\"] == 0, position_ids,\n", + "# np.ones_like(position_ids))\n", + "batch[\"position_ids\"] = position_ids\n", + "\n", + "torch_batch = {}\n", + "for k in batch:\n", + " torch_batch[k] = 
torch.Tensor(batch[k]).to(torch.device(\"cuda:1\")).long()\n", + "labels = batch.pop(\"labels\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "@partial(jax.jit, device=jax.devices(\"gpu\")[2])\n", + "def compute_loss(params, batch):\n", + " # Currently we don't support non-deterministic training with remat,\n", + " # so train=False. This arg has no other impact.\n", + " outs = model(**batch, params=params, train=False, output_hidden_states=True)\n", + " logits = outs[0]\n", + " hidden_states = outs[1]\n", + " loss = loss_fn(logits, labels, ignore_ids)\n", + " return logits, loss, hidden_states\n", + "logits, loss, hidden_states = compute_loss(loaded_params, batch)\n", + "print(np.array(loss))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hf_model = hf_model.to(torch.device(\"cuda:1\"))\n", + "with torch.no_grad():\n", + " hf_out = hf_model(**torch_batch, output_hidden_states=True)\n", + " hf_loss = hf_out.loss.detach().cpu().numpy()\n", + " hf_logits = hf_out.logits.detach().cpu().numpy()\n", + "print(hf_loss)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(hf_loss, loss)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "jax_hid_0 = np.array(hidden_states[0])\n", + "hf_hid_0 = hf_out.hidden_states[0].detach().cpu().numpy()\n", + "print(np.allclose(jax_hid_0, hf_hid_0))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "jax_hid_1 = np.array(hidden_states[1])\n", + "hf_hid_1 = hf_out.hidden_states[1].detach().cpu().numpy()\n", + "print(np.allclose(jax_hid_1, hf_hid_1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from EasyLM.models.llama.llama_model import FlaxLLaMABlock\n", + "\n", + "block = FlaxLLaMABlock(config, dtype=dtype)\n", + "block_param = loaded_params['transformer']['h']['0']\n", + "@jax.jit\n", + "def compute(block_param, hidden_state):\n", + " x = block.apply({\"params\": block_param}, hidden_state,\n", + " batch[\"attention_mask\"],\n", + " batch[\"position_ids\"])\n", + " return x\n", + "manual_jax_hid_1 = compute(block_param, hidden_states[0])\n", + "manual_jax_hid_1 = np.array(manual_jax_hid_1)\n", + "print(np.allclose(jax_hid_1, manual_jax_hid_1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hf_hid_attention_mask = hf_model.model._prepare_decoder_attention_mask(\n", + " torch_batch[\"attention_mask\"], torch_batch[\"input_ids\"].shape,\n", + " hf_out.hidden_states[0], 0)\n", + "with torch.no_grad():\n", + " manual_hf_hid_1 = hf_model.model.layers[0](\n", + " hf_out.hidden_states[0], hf_hid_attention_mask,\n", + " torch_batch[\"position_ids\"])[0]\n", + "manual_hf_hid_1 = manual_hf_hid_1.detach().cpu().numpy()\n", + "print(np.allclose(hf_hid_1, manual_hf_hid_1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from EasyLM.models.llama.llama_model import FlaxLLaMAAttention, FlaxLLaMAMLP, RMSNorm\n", + "flax_attn = FlaxLLaMAAttention(config, dtype=dtype, param_dtype=dtype)\n", + "norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, param_dtype=dtype)\n", + "ffn = 
FlaxLLaMAMLP(config, dtype=dtype, param_dtype=dtype)\n", + "\n", + "attn_param = block_param[\"attention\"]\n", + "ffn_param = block_param[\"feed_forward\"]\n", + "attn_norm_param = block_param[\"attention_norm\"]\n", + "ffn_norm_param = block_param[\"ffn_norm\"]\n", + "\n", + "@jax.jit\n", + "def compute(block_param, hidden_state):\n", + " normed_hidden_state = norm.apply({\"params\": block_param[\"attention_norm\"]}, hidden_state)\n", + " attn_outputs = flax_attn.apply({\"params\": block_param[\"attention\"]}, normed_hidden_state,\n", + " attention_mask=batch[\"attention_mask\"],\n", + " position_ids=batch[\"position_ids\"],\n", + " fcm_mask=None\n", + " )\n", + " attn_output = attn_outputs[0]\n", + " hidden_state = hidden_state + attn_output\n", + "\n", + " ffn_normed_hidden = norm.apply({\"params\": block_param[\"ffn_norm\"]}, hidden_state)\n", + " feed_forward_hidden_state = ffn.apply({\"params\": block_param[\"feed_forward\"]},\n", + " ffn_normed_hidden,\n", + " deterministic=True,\n", + " )\n", + " hidden_state = hidden_state + feed_forward_hidden_state\n", + " return hidden_state, (normed_hidden_state, attn_outputs[0], ffn_normed_hidden, feed_forward_hidden_state)\n", + "split_manual_jax_hid_1, (normed_hidden, attn_out, ffn_normed, ffn_out) = compute(block_param, hidden_states[0])\n", + "split_manual_jax_hid_1 = np.array(split_manual_jax_hid_1)\n", + "print(np.allclose(jax_hid_1, split_manual_jax_hid_1))\n", + "print(np.allclose(manual_jax_hid_1, split_manual_jax_hid_1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(np.allclose(split_manual_jax_hid_1, hf_hid_1))\n", + "# print(split_manual_jax_hid_1)\n", + "# print(manual_jax_hid_1)\n", + "# print(hf_hid_1)\n", + "print(jax_hid_1.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hf_block = hf_model.model.layers[0]\n", + "def hf_compute(hidden_states, attention_mask, position_ids):\n", + " with torch.no_grad():\n", + " residual = hidden_states\n", + " attn_normed = hf_block.input_layernorm(hidden_states)\n", + "\n", + " # Self Attention\n", + " attn_out, _, _ = hf_block.self_attn(\n", + " hidden_states=attn_normed,\n", + " attention_mask=attention_mask,\n", + " position_ids=position_ids,\n", + " )\n", + " hidden_states = residual + attn_out\n", + "\n", + " # Fully Connected\n", + " residual = hidden_states\n", + " ffn_normed = hf_block.post_attention_layernorm(hidden_states)\n", + " ffn_out = hf_block.mlp(ffn_normed)\n", + " hidden_states = residual + ffn_out\n", + " return hidden_states.detach().cpu().numpy(), (attn_normed, attn_out, ffn_normed, ffn_out)\n", + "\n", + "\n", + "manual_hf_hid_1, (hf_attn_normed, hf_attn_out, hf_ffn_normed,\n", + " hf_ffn_out) = hf_compute(hf_out.hidden_states[0],\n", + " hf_hid_attention_mask,\n", + " torch_batch[\"position_ids\"])\n", + "print(np.allclose(hf_hid_1, manual_hf_hid_1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(np.allclose(np.array(normed_hidden), hf_attn_normed.detach().cpu().numpy()))\n", + "print(np.allclose(np.array(attn_out), hf_attn_out.detach().cpu().numpy(), atol=1e-7))\n", + "print(np.allclose(np.array(ffn_normed), hf_ffn_normed.detach().cpu().numpy(), atol=1e-6))\n", + "print(np.allclose(np.array(ffn_out), hf_ffn_out.detach().cpu().numpy(), atol=1e-6))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + 
"source": [ + "import flax.linen as nn\n", + "dense_1 = nn.Dense(\n", + " config.intermediate_size,\n", + " dtype=dtype,\n", + " use_bias=False,\n", + " kernel_init=jax.nn.initializers.normal(config.initializer_range),\n", + ")\n", + "dense_2 = nn.Dense(\n", + " config.hidden_size,\n", + " dtype=dtype,\n", + " use_bias=False,\n", + " kernel_init=jax.nn.initializers.normal(config.initializer_range),\n", + ")\n", + "dense_3 = nn.Dense(\n", + " config.intermediate_size,\n", + " dtype=dtype,\n", + " use_bias=False,\n", + " kernel_init=jax.nn.initializers.normal(config.initializer_range),\n", + ")\n", + "mlp_block = hf_block.mlp\n", + "\n", + "@jax.jit\n", + "def compute(x):\n", + " gate = dense_1.apply({\"params\": ffn_param[\"w1\"]}, x)\n", + " up = dense_3.apply({\"params\": ffn_param[\"w3\"]}, x)\n", + " x = dense_2.apply({\"params\": ffn_param[\"w2\"]}, nn.silu(gate) * up)\n", + " return x, gate, up\n", + "def torch_compute(x):\n", + " with torch.no_grad():\n", + " gate = mlp_block.gate_proj(x)\n", + " up = mlp_block.up_proj(x)\n", + " x = mlp_block.down_proj(mlp_block.act_fn(gate) * up)\n", + " return x, gate, up\n", + "\n", + "hf_ffn_in = hf_ffn_normed.detach().cpu().numpy()\n", + "print(np.allclose(ffn_param[\"w1\"][\"kernel\"], mlp_block.gate_proj.weight.detach().cpu().numpy().transpose()))\n", + "print(np.allclose(ffn_param[\"w2\"][\"kernel\"], mlp_block.down_proj.weight.detach().cpu().numpy().transpose()))\n", + "print(np.allclose(ffn_param[\"w3\"][\"kernel\"], mlp_block.up_proj.weight.detach().cpu().numpy().transpose()))\n", + "\n", + "\n", + "ffn_out_hf_input, ffn_gate, ffn_up = [np.array(x) for x in compute(hf_ffn_in)]\n", + "hf_ffn_out, hf_gate, hf_up = torch_compute(hf_ffn_normed)\n", + "\n", + "print(np.allclose(hf_ffn_out.detach().cpu().numpy(), mlp_block(hf_ffn_normed).detach().cpu().numpy()))\n", + "print(np.allclose(hf_gate.detach().cpu().numpy(), ffn_gate, atol=1e-6))\n", + "print(np.allclose(hf_up.detach().cpu().numpy(), ffn_up, atol=1e-6))\n", + "print(np.allclose(ffn_out_hf_input, hf_block.mlp(hf_ffn_normed).detach().cpu().numpy()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(hf_gate.detach().cpu().numpy())\n", + "print(ffn_gate)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "attn_param[\"wq\"]\n", + "attn_param[\"wk\"]\n", + "attn_param[\"wv\"]\n", + "attn_param[\"wo\"]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}