Commit 775d4a6: subwordXLM, wip
markus583 committed Dec 21, 2023
1 parent 3ae92ea
Showing 8 changed files with 311 additions and 62 deletions.
4 changes: 2 additions & 2 deletions configs/canine_stratify_0.1_12layers_long.json
@@ -7,9 +7,9 @@
  "do_train": true,
  "do_eval": true,
  "evaluation_strategy": "steps",
- "per_device_train_batch_size": 64,
+ "per_device_train_batch_size": 32,
  "per_device_eval_batch_size": 32,
- "gradient_accumulation_steps": 1,
+ "gradient_accumulation_steps": 2,
  "eval_accumulation_steps": 8,
  "dataloader_num_workers": 4,
  "preprocessing_num_workers": 6,
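Net effect of this change: the effective per-device batch size is unchanged at 64, since 32 × 2 gradient-accumulation steps = 64.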
41 changes: 41 additions & 0 deletions configs/xlmr_stratify_0.1_3layers.json
@@ -0,0 +1,41 @@
{
"model_name_or_path": "xlm-roberta-base",
"output_dir": "xlmr-TEST",
"train_text_path": "data/sentence/train.parquet",
"valid_text_path": "data/sentence/valid.parquet",
"block_size": 512,
"use_bert": true,
"do_train": true,
"do_eval": true,
"evaluation_strategy": "steps",
"per_device_train_batch_size": 64,
"per_device_eval_batch_size": 32,
"gradient_accumulation_steps": 1,
"eval_accumulation_steps": 8,
"dataloader_num_workers": 32,
"preprocessing_num_workers": 6,
"learning_rate": 1e-4,
"save_strategy": "steps",
"fp16": false,
"max_steps": 2000000,
"save_steps": 100000,
"eval_steps": 50000000000,
"logging_steps": 50,
"report_to": "wandb",
"is_decoder": false,
"remove_unused_columns": false,
"lookahead": null,
"one_sample_per_line": false,
"do_sentence_training": true,
"do_auxiliary_training": true,
"warmup_steps": 5000,
"adapter_warmup_steps": 0,
"adapter_lr_multiplier": 1,
"ngram_order": 1,
"non_punctuation_sample_ratio": 0.1,
"prediction_loss_only": true,
"use_auxiliary": true,
"ddp_timeout": 3600,
"use_subwords": true,
"num_hidden_layers": 3
}
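For context (illustration only, not part of the diff): judging from run.sh below, a JSON config like this is passed as the single argument to wtpsplit/train/train.py; how its fields are actually parsed there is not shown in this commit. A minimal sketch, under that assumption, of turning the model-related keys into a Hugging Face config in plain Python:

import json

from transformers import AutoConfig

# Hypothetical illustration: read the new training config and build a model
# config from its model-related keys. The real wiring lives in
# wtpsplit/train/train.py, which this commit does not show.
with open("configs/xlmr_stratify_0.1_3layers.json") as f:
    args = json.load(f)

config = AutoConfig.from_pretrained(
    args["model_name_or_path"],                   # "xlm-roberta-base"
    num_hidden_layers=args["num_hidden_layers"],  # truncated to 3 layers
)
print(type(config).__name__, config.num_hidden_layers)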
1 change: 1 addition & 0 deletions run.sh
@@ -1 +1,2 @@
# TODO: cleanup in case of no .arrow files but cache-* files available.
python3 ~/wtpsplit/xla_spawn.py --num_cores ${TPU_NUM_DEVICES} wtpsplit/train/train.py $1
4 changes: 4 additions & 0 deletions tpu_starter.sh
@@ -0,0 +1,4 @@
for var in "$@"
do
until gcloud compute tpus tpu-vm create $var --zone=europe-west4-a --accelerator-type=v3-8 --version=tpu-vm-base; do sleep 5; done
done
17 changes: 16 additions & 1 deletion wtpsplit/configs.py
@@ -1,4 +1,4 @@
-from transformers import AutoConfig, BertConfig, CanineConfig
+from transformers import AutoConfig, BertConfig, CanineConfig, XLMRobertaConfig


class LACanineConfig(CanineConfig):
@@ -37,7 +37,22 @@ def __init__(

self.num_hash_buckets = num_hash_buckets
self.num_hash_functions = num_hash_functions

class SubwordXLMConfig(XLMRobertaConfig):
    """Config for XLM-R and XLM-V models. Used for token-level training.

    Args:
        XLMRobertaConfig: Base class.
    """

    model_type = "xlm-token"

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)


AutoConfig.register("bert-char", BertCharConfig)
AutoConfig.register("la-canine", LACanineConfig)
AutoConfig.register("xlm-token", SubwordXLMConfig)
79 changes: 78 additions & 1 deletion wtpsplit/models.py
@@ -6,6 +6,7 @@
from torch import nn
from transformers import AutoModel, AutoModelForTokenClassification
from transformers.models.bert.modeling_bert import BertEncoder, BertForTokenClassification, BertModel, BertPooler
from transformers.models.xlm_roberta import XLMRobertaModel, XLMRobertaForTokenClassification
from transformers.models.canine.modeling_canine import (
_PRIMES,
ACT2FN,
@@ -26,8 +27,9 @@
ConvProjection,
TokenClassifierOutput,
)
from torchinfo import summary

-from wtpsplit.configs import BertCharConfig, LACanineConfig
+from wtpsplit.configs import BertCharConfig, LACanineConfig, SubwordXLMConfig
from wtpsplit.utils import Constants


@@ -956,9 +958,84 @@ def forward(
return_dict,
)

class SubwordXLMForTokenClassification(XLMRobertaForTokenClassification):
    config_class = SubwordXLMConfig

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        language_ids=None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
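        # language_ids is accepted (presumably for signature compatibility with
        # the other wtpsplit models) but is ignored; it is not forwarded below.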
        return super().forward(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            labels,
            output_attentions,
            output_hidden_states,
            return_dict,
        )



AutoModel.register(LACanineConfig, LACanineModel)
AutoModelForTokenClassification.register(LACanineConfig, LACanineForTokenClassification)

AutoModel.register(BertCharConfig, BertCharModel)
AutoModelForTokenClassification.register(BertCharConfig, BertCharForTokenClassification)

AutoModel.register(SubwordXLMConfig, SubwordXLMForTokenClassification)
AutoModelForTokenClassification.register(SubwordXLMConfig, SubwordXLMForTokenClassification)

if __name__ == "__main__":
    # test XLM
    from transformers import AutoConfig, AutoTokenizer

    model_str = "xlm-roberta-base"
    config = AutoConfig.from_pretrained(model_str)
    config.num_labels = 4
    config.num_hidden_layers = 9
    backbone = SubwordXLMForTokenClassification.from_pretrained(model_str, config=config)
    print(summary(backbone, depth=4))

    # some sample input
    text = "This is a test\n sentence \n\n"
    tokenizer = AutoTokenizer.from_pretrained(model_str)

    tokens = tokenizer(text, return_tensors="pt")
    from tokenizers import AddedToken

    tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
    print(tokenizer.tokenize(text))
    print(tokenizer.encode(text))
    print(tokens)
    # forward pass
    print(backbone(**tokens))
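Illustration only (not part of the commit): with the registrations above, the Auto classes can also build the subword model directly from a config. A minimal sketch, assuming wtpsplit is importable:

from transformers import AutoModelForTokenClassification

from wtpsplit.configs import SubwordXLMConfig
from wtpsplit.models import SubwordXLMForTokenClassification

# Build from a freshly constructed config: random init, no checkpoint download.
config = SubwordXLMConfig(num_hidden_layers=3, num_labels=4)
model = AutoModelForTokenClassification.from_config(config)
assert isinstance(model, SubwordXLMForTokenClassification)
print(model.num_labels)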
