How to run Bert with DTR? #7

Closed · LiuXiaoxuanPKU opened this issue Jan 15, 2022 · 6 comments

@LiuXiaoxuanPKU commented Jan 15, 2022

Hi,

Thanks for the repo! I am playing with DTR and trying to run Bert with it. Following the experiment code in the repo, I made the following modifications to my training script:

model = ...
mem_budget = 6*1024**3
torch.set_memory_budget(mem_budget)
model._apply(lambda v: v.detach().checkpoint())

for step, batch in enumerate(train_dataloader):
  for k, v in batch.items():
    batch[k] = v.checkpoint()

  outputs = model(**batch)
  loss = outputs.loss
  optimizer.zero_grad()
  loss.backward()

  for k, v in batch.items():
    batch[k] = v.decheckpoint()
  loss = loss.decheckpoint()

I ran the code with the DTR-modified PyTorch and it throws a segfault; any comments on this are highly appreciated!
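
For reference, a minimal sketch of one way to localize such a crash is to enable the standard-library faulthandler when launching the training script, so the Python-level traceback is printed when the process receives SIGSEGV (the script name below is only a placeholder):

# Dump the Python traceback on SIGSEGV before the process dies
# (train_bert_dtr.py is a placeholder for the actual training script).
python -X faulthandler train_bert_dtr.py --use_dtr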

@MarisaKirisame (Collaborator)

Hi, can you provide a minimal reproducible example script? It is pretty hard to tell what happened from the above.

@LiuXiaoxuanPKU (Author) commented Jan 15, 2022

I tried to reduce the training script; does the following look good?

""" Finetuning a 🤗 Transformers model for sequence classification on GLUE."""
import argparse
import logging
import math
import os
import random
import torch

import datasets
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import transformers
from huggingface_hub import Repository
from transformers import (
    AdamW,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    PretrainedConfig,
    SchedulerType,
    default_data_collator,
    get_scheduler,
    set_seed,
)

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=True,
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--max_gradient_norm",
        type=float,
        default=1.,
        help="Maximum norm of gradient.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument("--use_dtr", action='store_true')
    
    args = parser.parse_args()

    return args


def main():
    task_name = 'sst2'
    max_length = 128
    pad_to_max_length = True
    gradient_accumulation_steps = 1
    args = parse_args()
    use_dtr = args.use_dtr

    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Get the datasets
    # Downloading and loading a dataset from the hub.
    raw_datasets = load_dataset("glue", task_name)
    # Labels
    label_list = raw_datasets["train"].features["label"].names
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
    config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=task_name)
    model = AutoModelForSequenceClassification.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
    )
    model.to(args.device)

    if use_dtr:
        torch.set_memory_budget(10*1024**3)
        model._apply(lambda v: v.detach().checkpoint())
        
    # Preprocessing the datasets
    sentence1_key, sentence2_key = "sentence", None

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and task_name is not None
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!"
            )
            label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
        else:
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result."
            )
    elif task_name is None:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    if label_to_id is not None:
        model.config.label2id = label_to_id
        model.config.id2label = {id: label for label, id in config.label2id.items()}
    
    model.config.label2id = {l: i for i, l in enumerate(label_list)}
    model.config.id2label = {id: label for label, id in config.label2id.items()}

    padding = "max_length"

    def preprocess_function(examples):
        # Tokenize the texts
        texts = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*texts, padding=padding, max_length=max_length, truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks)
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model will expect that.
                result["labels"] = examples["label"]
        return result

    processed_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            remove_columns=raw_datasets["train"].column_names,
            desc="Running tokenizer on dataset",
        )
    train_dataset = processed_datasets["train"]

    train_max_length = 0
    for item in train_dataset:
        if len(item['input_ids']) > train_max_length:
            train_max_length = len(item['input_ids'])

    # DataLoaders creation:
    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size
    )
    
    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)


    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=args.max_train_steps,
    )

    # Get the metric function
    metric = load_metric("glue", task_name)

    # Train!
    total_batch_size = args.per_device_train_batch_size

    progress_bar = tqdm(range(args.max_train_steps))
    completed_steps = 0

    iter = 0
    best_metric = 0
    for epoch in range(args.num_train_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            iter += 1
            for k, v in batch.items():
                batch[k] = v.to(args.device)
            
            if use_dtr:
                for k, v in batch.items():
                    batch[k] = v.checkpoint()
        
            outputs = model(**batch)
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            
            if use_dtr:
                for k, v in batch.items():
                    batch[k] = v.decheckpoint()
                loss = loss.decheckpoint()

            torch.nn.utils.clip_grad.clip_grad_norm_(model.parameters(), args.max_gradient_norm)
            optimizer.step()
            lr_scheduler.step()
            progress_bar.update(1)
            completed_steps += 1

            if completed_steps >= args.max_train_steps:
                break

if __name__ == "__main__":
    main()

The command to run it is:

python run_glue_dtr_minimal.py \
  --model_name_or_path bert-large-cased \
  --per_device_train_batch_size 32 \
  --learning_rate 1e-5 \
  --num_train_epochs 5 \
  --use_dtr

Please let me know if you can reproduce, thanks a lot!

@AD1024 (Member) commented Jan 16, 2022

> I tried to reduce the training script; does the following look good? […]

Hi Xiaoxuan, thank you for providing this example! I will have a look at it during the weekend.

@LiuXiaoxuanPKU (Author)

Hi, just checking in - any updates on this? Please let me know if you cannot reproduce it, thanks!

@MarisaKirisame (Collaborator)

We are able to reproduce it and are debugging it right now - you can also just run the script under gdb to see what's wrong.
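
For reference, a minimal sketch of doing that with the reduced script above:

# Run the reduced script under gdb so the native backtrace is available
# when the segfault happens inside the DTR-modified PyTorch.
gdb --args python run_glue_dtr_minimal.py \
  --model_name_or_path bert-large-cased \
  --per_device_train_batch_size 32 \
  --use_dtr
# inside gdb: type `run` to start, then `bt` after the SIGSEGV to print the C++ stack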

@MarisaKirisame (Collaborator)

Hey, sorry for the late reply! The work was done a while ago, but I forgot to mention you: it is at uwsampl/pytorch#71. Let me know if there is anything I can do to help.
