From 8a8987f3cd90f650e286d9b15998244149811f77 Mon Sep 17 00:00:00 2001 From: Vincent Nguyen Date: Fri, 3 Jan 2025 22:39:33 +0100 Subject: [PATCH] Code clean-ups (#171) * misc optimization * clear cache after translation during scoring * allow more recompiles * set rope / position_embeddings at model build * remove BPTT * clarify pad_mask(true=yes we pad, so we won't attend) and attn_mask(true=yes we attend) * preallocate KV cache even in "pytorch" path (same as flash) * reduce config updates --- .github/workflows/push.yml | 22 +-- README.md | 21 +- eole/bin/convert/convert_HF.py | 9 +- eole/bin/run/serve.py | 2 +- eole/config/models.py | 70 +++++-- eole/config/run.py | 14 +- eole/config/training.py | 4 - eole/decoders/ensemble.py | 11 +- eole/decoders/rnn_decoder.py | 4 +- eole/decoders/transformer_base.py | 20 +- eole/decoders/transformer_decoder.py | 79 +++----- eole/decoders/transformer_lm_decoder.py | 64 +++--- eole/encoders/cnn_encoder.py | 2 +- eole/encoders/encoder.py | 7 +- eole/encoders/mean_encoder.py | 10 +- eole/encoders/rnn_encoder.py | 2 +- eole/encoders/transformer.py | 33 ++-- eole/inputters/text_utils.py | 12 +- eole/models/model.py | 99 ++++++---- eole/modules/multi_headed_attn.py | 186 ++++++++---------- eole/modules/rope.py | 40 ++-- eole/predict/__init__.py | 2 +- eole/predict/encoder.py | 9 +- eole/predict/generator.py | 5 +- eole/predict/inference.py | 12 +- eole/predict/translator.py | 29 ++- eole/tests/data/data_lm/gen-beam-sol.txt | 2 +- .../data/data_lm/gen-sampling-beams-sol2.txt | 14 +- eole/tests/data/data_lm/gen-sampling-sol.txt | 6 +- eole/tests/pull_request_check.sh | 174 ++++++++-------- eole/tests/test_events.py | 2 +- eole/tests/test_model.yml | 1 - eole/tests/test_model_lm.yml | 1 - eole/tests/test_model_lm/config.json | 2 +- eole/tests/test_models.py | 16 +- eole/train_single.py | 8 +- eole/trainer.py | 113 ++++------- eole/utils/loss.py | 25 +-- eole/utils/misc.py | 2 +- eole/utils/optimizers.py | 8 +- recipes/cometkiwi/cometkiwi-xl-eole.yaml | 4 - recipes/cometkiwi/cometkiwi-xxl-eole.yaml | 4 - 42 files changed, 562 insertions(+), 588 deletions(-) diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 60483644..1e1d7f7c 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -48,7 +48,7 @@ jobs: -src_vocab /tmp/eole.vocab.src \ -tgt_vocab /tmp/eole.vocab.tgt \ && rm -rf /tmp/sample - - name: Test field/transform dump + - name: Testing architecture rnn sample dump... 
run: | # The dumped fields are used later when testing tools python eole/bin/main.py train \ @@ -61,7 +61,7 @@ jobs: -tgt_vocab /tmp/eole.vocab.tgt \ -src_vocab_size 1000 \ -tgt_vocab_size 1000 - - name: Test RNN training + - name: Testing architecture rnn training run: | python eole/bin/main.py train \ -config eole/tests/data/data.yaml \ @@ -75,7 +75,7 @@ jobs: -tensorboard \ -tensorboard_log_dir /tmp/logs_train python eole/tests/test_events.py --logdir /tmp/logs_train -tensorboard_checks train - - name: Test RNN training and validation + - name: Testing architecture rnn training and validation run: | python eole/bin/main.py train \ -config eole/tests/data/data.yaml \ @@ -90,7 +90,7 @@ jobs: -tensorboard_log_dir /tmp/logs_train_and_valid python eole/tests/test_events.py --logdir /tmp/logs_train_and_valid -tensorboard_checks train python eole/tests/test_events.py --logdir /tmp/logs_train_and_valid -tensorboard_checks valid - - name: Test RNN training with coverage + - name: Testing architecture rnn training w/ coverage run: | python eole/bin/main.py train \ -config eole/tests/data/data.yaml \ @@ -101,7 +101,7 @@ jobs: -report_every 5 \ -model '{"architecture": "rnn", "hidden_size": 10, "embeddings": {"word_vec_size": 5, "position_encoding_type": None}, "decoder": {"coverage_attn": True, "lambda_coverage": 0.1}}' \ -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 10}' - - name: Test Transformer training with align + - name: Testing architecture custom transformer training w/ align run: | python eole/bin/main.py train \ -config eole/tests/data/align_data.yaml \ @@ -112,7 +112,7 @@ jobs: -model '{"layers": 4, "hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16}, "encoder": {"encoder_type": "transformer", "heads": 2}, "decoder": {"decoder_type": "transformer", "lambda_align": 0.05, "alignment_layer": 2, "alignment_heads": 0, "heads": 2}}' \ -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 10, "dropout_steps": [0, 3, 7], "dropout": [0.3, 0.2, 0.1], "attention_dropout": [0.2, 0.2, 0.1]}' \ -report_every 5 \ - - name : Test Transformer training and validation with dynamic scoring + - name : Testing architecture custom transformer training w/ validation with dynamic scoring run: | python3 eole/bin/main.py train \ -config eole/tests/data/data.yaml \ @@ -129,7 +129,7 @@ jobs: -tensorboard_log_dir /tmp/logs_dynamic-scoring \ -dump_preds /tmp/dump_preds python eole/tests/test_events.py --logdir /tmp/logs_dynamic-scoring -tensorboard_checks valid_metrics - - name : Test Transformer training and validation with dynamic scoring and maxrelative + - name : Testing architecture transformer training w/ validation with dynamic scoring and maxrelative run: | python3 eole/bin/main.py train \ -config eole/tests/data/data.yaml \ @@ -146,7 +146,7 @@ jobs: -tensorboard_log_dir /tmp/logs_dynamic-scoring_and_relative \ -dump_preds /tmp/dump_preds python eole/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_relative -tensorboard_checks valid_metrics - - name : Test Transformer training and validation with dynamic scoring and rotary + - name : Testing architecture transformer training w/ validation with dynamic scoring and rotary run: | python3 eole/bin/main.py train \ -config eole/tests/data/data.yaml \ @@ -154,7 +154,7 @@ jobs: -tgt_vocab /tmp/eole.vocab.tgt \ -src_vocab_size 1000 \ -tgt_vocab_size 1000 \ - -model '{"architecture": "transformer", "layers": 4, "heads": 2, "hidden_size": 16, "transformer_ff": 64, 
"embeddings": {"word_vec_size": 16, "position_encoding_type": "Rotary"}}' \ + -model '{"architecture": "transformer", "layers": 4, "heads": 2, "hidden_size": 16, "transformer_ff": 64, "rope_config": {}, "embeddings": {"word_vec_size": 16, "position_encoding_type": "Rotary"}}' \ -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 10, "valid_steps": 5}' \ -report_every 2 \ -valid_metrics "BLEU" "TER" \ @@ -163,7 +163,7 @@ jobs: -tensorboard_log_dir /tmp/logs_dynamic-scoring_and_rotary \ -dump_preds /tmp/dump_preds python eole/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_rotary -tensorboard_checks valid_metrics - - name : Test Transformer training and validation with dynamic scoring and alibi + - name : Testing architecture transformer training w/ validation with dynamic scoring and alibi run: | python3 eole/bin/main.py train \ -config eole/tests/data/data.yaml \ @@ -180,7 +180,7 @@ jobs: -tensorboard_log_dir /tmp/logs_dynamic-scoring_and_alibi \ -dump_preds /tmp/dump_preds python eole/tests/test_events.py --logdir /tmp/logs_dynamic-scoring_and_alibi -tensorboard_checks valid_metrics - - name: Test LM training + - name: Testing architecture custom decoder only training run: | python eole/bin/main.py train \ -config eole/tests/data/lm_data.yaml \ diff --git a/README.md b/README.md index e5ac5d6a..25bc7d24 100644 --- a/README.md +++ b/README.md @@ -2,15 +2,21 @@ [![Documentation](https://img.shields.io/badge/docs-latest-blue.svg)](https://eole-nlp.github.io/eole) -Open language modeling toolkit based on [PyTorch](https://pytorch.org). +Open language modeling toolkit based on [PyTorch](https://pytorch.org) initially spun-off of OpenNMT-py -## 👷‍♂️🚧 Work in Progress +We aim to maintain the research-friendly approach of the original project while including latest architectures (LLMs) and various other techniques. +Our goal is to provide a comprehensive yet compact and modular codebase for experimenting with various types of language models (encoder, decoder, seq2seq). -[EOLE](https://github.com/eole-nlp/eole) is a spin-off of the [OpenNMT-py](https://github.com/opennmt/opennmt-py) project. We aim to maintain the research-friendly approach of the original project while updating the structure and expanding it to include new topics related to large language models (LLMs) and various other techniques. Our goal is to provide a comprehensive yet compact and modular codebase for experimenting with various types of language models (encoder, decoder, seq2seq). +## Latest developments ---- +- **Web-based (Google translator-like) interface** featuring the latest EuroLLM-8B-Instruct LLM: read more [here](https://github.com/eole-nlp/eole/tree/main/recipes/eurollm) +- **Estimator layer** which enables to rescore multiple beams in the same model. Read article [here](https://medium.com/p/05b00b271a47) and [here](https://medium.com/p/7dccfe167814) +- **Support Hugging Face Tokenizers** for better compatiblity +- **New recipes** for TowerInstruct-llama2 and TowerInstruct-Mistral +- **Support latest models** for Llama3.1, Gemma2, Pixtral +- **Replicate CometKiwi(XL/XXL)** Encoder+Estimator models -### Current State +## Work completed We have made significant progress in several areas: @@ -18,14 +24,13 @@ We have made significant progress in several areas: - **Command Line Entry Points**: Improved using structured subparsers for better organization. 
- **Reproducible Recipes**: Provided for widely used models and tasks, ensuring consistency and reliability. - **Core API Simplification**: Refined around the new configuration objects for ease of use. +- **Revamped FastAPI-based server**: see the EuroLLM-9B-Instruct example above ### Future Directions There are still several exciting avenues to explore: - **Further Simplification and Refactoring**: Continue enhancing the codebase for clarity and efficiency. -- **Inference Server**: Develop a robust solution for model inference. -- **Additional Recipes**: Expand the library of reproducible recipes. - **Documentation**: Enhance and expand the documentation for better user guidance. - **Test Coverage**: Improve testing to ensure code reliability and performance. - **Logging Enhancements**: Implement more sophisticated logging mechanisms. @@ -37,7 +42,7 @@ There are still several exciting avenues to explore: - **Versatile Training and Inference**: Train from scratch, finetune, and infer models of various architectures including Transformer Encoder/Decoder/EncoderDecoder and RNN EncoderDecoder. - **Dynamic Data Transforms**: Apply on-the-fly transformations in the dataloading logic for both training and inference. -- **Comprehensive LLM Support**: Includes converters for Llama, Mistral, Phi, OpenLlama, Redpajama, MPT-7B, and Falcon models. +- **Comprehensive LLM Support**: Includes converters for Llama, Mistral, Phi, Gemma, and more. - **Advanced Quantization**: Support for 8-bit and 4-bit quantization, along with LoRA adapters, with or without checkpointing, as well as mixed precision (FP16). - **Efficient Finetuning**: Finetune 7B and 13B models on a single RTX 24GB GPU using 4-bit quantization. - **Flexible Inference**: Perform inference in 4-bit or 8-bit using the same layer quantization methods as in finetuning. diff --git a/eole/bin/convert/convert_HF.py b/eole/bin/convert/convert_HF.py index a7987f36..f7e0a0f9 100755 --- a/eole/bin/convert/convert_HF.py +++ b/eole/bin/convert/convert_HF.py @@ -836,7 +836,7 @@ def get_weight(checkpoint, tensor_name): for target in targetlist: if target in key_maps[arch].keys(): source = key_maps[arch][target] - if type(source) == tuple: + if isinstance(source, tuple): srckey = source[0] srcmap = source[1] else: @@ -847,7 +847,7 @@ def get_weight(checkpoint, tensor_name): ) if w is not None: - if type(source) == tuple: + if isinstance(source, tuple): w = eval("w" + srcmap).contiguous() eole_safetensor[ eole_prefix + str(i) + target + param ] @@ -859,9 +859,8 @@ def get_weight(checkpoint, tensor_name): idx = 1 for p in ["weight", "bias"]: if ".input_layernorm." + p in key_maps[arch].keys(): - if ( - type(key_maps[arch][".input_layernorm." + p]) - == tuple + if isinstance( + key_maps[arch][".input_layernorm." + p], tuple ): w = get_weight( checkpoint, diff --git a/eole/bin/run/serve.py b/eole/bin/run/serve.py index 7387b6db..cae199eb 100644 --- a/eole/bin/run/serve.py +++ b/eole/bin/run/serve.py @@ -267,7 +267,7 @@ def infer(self, inputs, settings={}, is_chat=False): """ Run inference on the given inputs.
""" - if type(inputs) == str: + if isinstance(inputs, str): inputs = [inputs] if not (self.loaded): self.load() diff --git a/eole/config/models.py b/eole/config/models.py index 8d564a29..5db288e4 100644 --- a/eole/config/models.py +++ b/eole/config/models.py @@ -287,9 +287,11 @@ class TransformerConfig(Config): @model_validator(mode="after") def _validate_transformer_config(self): + """ if self.position_encoding_type == PositionEncodingType.Rotary: if self.rope_config is None: self.rope_config = RotaryPositionConfig() + """ if self.add_qkvbias and "add_final_linear_bias" not in self.model_fields_set: self.update(add_final_linear_bias=True) return self @@ -503,40 +505,70 @@ def default_architecture(cls, data: Any) -> Any: return data def update_model_opts(self): - if self.embeddings is not None and self.embeddings.word_vec_size > 0: - self.embeddings.src_word_vec_size = self.embeddings.word_vec_size - self.embeddings.tgt_word_vec_size = self.embeddings.word_vec_size + update_dict = {} + if self.embeddings.position_encoding_type == PositionEncodingType.Rotary: + if not self.rope_config: + update_dict["rope_config"] = RotaryPositionConfig() + rope_config = update_dict["rope_config"] + else: + rope_config = self.rope_config + else: + rope_config = None - # Backward compatibility with "fix_word_vecs_*" opts - # We can probably drop this now... - # if hasattr(self, "fix_word_vecs_enc"): - # self.embeddings.freeze_word_vecs_enc = self.embeddings.fix_word_vecs_enc - # if hasattr(self, "fix_word_vecs_dec"): - # self.embeddings.freeze_word_vecs_dec = self.embeddings.fix_word_vecs_dec + if self.embeddings is not None and self.embeddings.word_vec_size > 0: + update_dict["embeddings"] = { + "src_word_vec_size": self.embeddings.word_vec_size, + "tgt_word_vec_size": self.embeddings.word_vec_size, + } + if self.embeddings is not None and "embeddings" in update_dict.keys(): + self.embeddings.update(**update_dict.pop("embeddings")) if ( getattr(self.encoder, "encoder_type", None) == "brnn" and self.decoder.decoder_type == "rnn" ): - self.decoder.bidirectional_encoder = True + update_dict["decoder"] = {"bidirectional_encoder": True} if self.encoder is not None: - self.encoder.src_word_vec_size = self.embeddings.src_word_vec_size + update_dict["encoder"] = { + "src_word_vec_size": self.embeddings.src_word_vec_size + } if getattr(self.encoder, "encoder_type", None) == "transformer": - self.encoder.position_encoding_type = ( - self.embeddings.position_encoding_type + update_dict["encoder"].update( + { + "position_encoding_type": self.embeddings.position_encoding_type, + "n_positions": self.embeddings.n_positions, + "rope_config": rope_config, + } ) - self.encoder.n_positions = self.embeddings.n_positions + update_dict[ + "position_encoding_type" + ] = self.embeddings.position_encoding_type + if self.encoder is not None and "encoder" in update_dict.keys(): + self.encoder.update(**update_dict.pop("encoder")) + if self.decoder is not None: - self.decoder.tgt_word_vec_size = self.embeddings.tgt_word_vec_size + update_dict["decoder"] = { + "tgt_word_vec_size": self.embeddings.tgt_word_vec_size + } if getattr(self.decoder, "decoder_type", None) in [ "transformer", "transformer_lm", ]: - self.decoder.position_encoding_type = ( - self.embeddings.position_encoding_type + update_dict["decoder"].update( + { + "position_encoding_type": self.embeddings.position_encoding_type, + "n_positions": self.embeddings.n_positions, + "rope_config": rope_config, + } ) - self.decoder.n_positions = self.embeddings.n_positions + 
update_dict[ + "position_encoding_type" + ] = self.embeddings.position_encoding_type + if self.decoder is not None and "decoder" in update_dict.keys(): + self.decoder.update(**update_dict.pop("decoder")) + + self.update(**update_dict) # causing some weird recursion issue in unit test, to investigate # if self.encoder is not None: @@ -584,7 +616,7 @@ def _validate_model_config(self): return self -class CustomModelConfig(BaseModelConfig): +class CustomModelConfig(TransformerConfig, BaseModelConfig): """ Wrap anything that does not fit a set common architecture. """ diff --git a/eole/config/run.py b/eole/config/run.py index 107cdd1c..24b05a4c 100644 --- a/eole/config/run.py +++ b/eole/config/run.py @@ -187,24 +187,26 @@ def _update_with_model_config(self): quant_type=training_config.quant_type, ) - model_config._validate_model_config() - # training_config._validate_running_config() # not sure it's needed - self.update( model=model_config, ) + update_dict = {} if "transforms" not in self.model_fields_set: - self.transforms = self._all_transform = transforms + update_dict["transforms"] = transforms + update_dict["_all_transform"] = transforms if "transforms_configs" not in self.model_fields_set: - self.transforms_configs = config_dict.get("transforms_configs", {}) + update_dict["transforms_configs"] = NestedAllTransformsConfig( + **config_dict.get("transforms_configs", {}) + ) if "compute_dtype" not in self.model_fields_set: self.compute_dtype = config_dict.get("training", {}).get( "compute_dtype", "fp16" ) for key, value in config_dict.get("inference", {}).items(): if key not in self.model_fields_set: - setattr(self, key, value) + update_dict[key] = value + self.update(**update_dict) class BuildVocabConfig( diff --git a/eole/config/training.py b/eole/config/training.py index d797d2f5..113bd827 100644 --- a/eole/config/training.py +++ b/eole/config/training.py @@ -212,10 +212,6 @@ class TrainingConfig( dropout_steps: List[int] = Field( default=[0], description="Steps at which dropout changes." ) - truncated_decoder: int = Field( - default=0, description="Truncated bptt." - ) # deprecated? - label_smoothing: float = Field( default=0.0, description="Label smoothing value epsilon. 
" diff --git a/eole/decoders/ensemble.py b/eole/decoders/ensemble.py index 2eec7cf2..ba2999fe 100644 --- a/eole/decoders/ensemble.py +++ b/eole/decoders/ensemble.py @@ -35,6 +35,7 @@ class EnsembleSrcEmb(nn.Module): def __init__(self, model_src_embs): super(EnsembleSrcEmb, self).__init__() self.model_src_embs = nn.ModuleList(model_src_embs) + self.word_padding_idx = model_src_embs[0].word_padding_idx def forward(self, src): src_emb = [model_src_emb(src) for model_src_emb in self.model_src_embs] @@ -48,10 +49,10 @@ def __init__(self, model_encoders): super(EnsembleEncoder, self).__init__() self.model_encoders = nn.ModuleList(model_encoders) - def forward(self, emb, mask=None): + def forward(self, emb, pad_mask=None, **kwargs): enc_out, enc_final_hs = zip( *[ - model_encoder(emb[i], mask) + model_encoder(emb[i], pad_mask=pad_mask, **kwargs) for i, model_encoder in enumerate(self.model_encoders) ] ) @@ -64,6 +65,7 @@ class EnsembleTgtEmb(nn.Module): def __init__(self, model_tgt_embs): super(EnsembleTgtEmb, self).__init__() self.model_tgt_embs = nn.ModuleList(model_tgt_embs) + self.word_padding_idx = model_tgt_embs[0].word_padding_idx def forward(self, tgt, step=None): tgt_emb = [model_tgt_emb(tgt, step) for model_tgt_emb in self.model_tgt_embs] @@ -164,9 +166,9 @@ class EnsembleModel(EncoderDecoderModel): """Dummy EncoderDecoderModel wrapping individual real EncoderDecoderModels.""" def __init__(self, models, raw_probs=False): - src_emb = EnsembleSrcEmb(model.src_emb for model in models) + src_emb = EnsembleSrcEmb([model.src_emb for model in models]) encoder = EnsembleEncoder(model.encoder for model in models) - tgt_emb = EnsembleTgtEmb(model.tgt_emb for model in models) + tgt_emb = EnsembleTgtEmb([model.tgt_emb for model in models]) decoder = EnsembleDecoder(model.decoder for model in models) hidden_size = models[0].hidden_size super(EnsembleModel, self).__init__( @@ -180,6 +182,7 @@ def __init__(self, models, raw_probs=False): [model.generator for model in models], raw_probs ) self.models = nn.ModuleList(models) + self.rope = models[0].rope def load_test_model(config, device_id=0): diff --git a/eole/decoders/rnn_decoder.py b/eole/decoders/rnn_decoder.py index 8564ac4e..ac0bdcfc 100644 --- a/eole/decoders/rnn_decoder.py +++ b/eole/decoders/rnn_decoder.py @@ -153,10 +153,10 @@ def forward(self, emb, enc_out, src_len=None, step=None, **kwargs): # Concatenates sequence of tensors along a new dimension. # NOTE: v0.3 to 0.4: dec_outs / attns[*] may not be list # since stack(Variable) was allowed. - if type(dec_outs) == list: + if isinstance(dec_outs, list): dec_outs = torch.stack(dec_outs, dim=1) for k in attns: - if type(attns[k]) == list: + if isinstance(attns[k], list): attns[k] = torch.stack(attns[k]) self.state["input_feed"] = dec_outs[:, -1, :].unsqueeze(0) diff --git a/eole/decoders/transformer_base.py b/eole/decoders/transformer_base.py index cec6044b..03bebc52 100644 --- a/eole/decoders/transformer_base.py +++ b/eole/decoders/transformer_base.py @@ -98,25 +98,23 @@ def update_dropout(self, dropout, attention_dropout): def _forward(self, *args, **kwargs): raise NotImplementedError - def _compute_dec_mask(self, tgt_pad_mask, future): + def _compute_attn_mask(self, tgt_pad_mask, future): tgt_len = tgt_pad_mask.size(-1) if not future: # Add triangular future_mask and pad_mask, result mask in (B, T, T). 
- future_mask = torch.ones( - [tgt_len, tgt_len], - device=tgt_pad_mask.device, - dtype=torch.uint8, + future_mask = torch.tril( + torch.ones( + (tgt_len, tgt_len), device=tgt_pad_mask.device, dtype=torch.bool + ), + diagonal=0, ) - future_mask = future_mask.tril_(0) if self.sliding_window > 0: future_mask = future_mask.triu_(-self.sliding_window) - future_mask = future_mask.bool() - future_mask = ~future_mask.view(1, tgt_len, tgt_len) - dec_mask = torch.gt(tgt_pad_mask + future_mask, 0) + attn_mask = ~tgt_pad_mask & future_mask.unsqueeze(0) else: # Only mask padding, result mask in (B, 1, T). - dec_mask = tgt_pad_mask - return dec_mask + attn_mask = ~tgt_pad_mask + return attn_mask class TransformerDecoderBase(DecoderBase): diff --git a/eole/decoders/transformer_decoder.py b/eole/decoders/transformer_decoder.py index 452d926f..42f8fddf 100644 --- a/eole/decoders/transformer_decoder.py +++ b/eole/decoders/transformer_decoder.py @@ -10,8 +10,7 @@ TransformerDecoderBase, ) from eole.modules.multi_headed_attn import ContextMHA -from eole.constants import LayerNorm, PositionEncodingType -from eole.modules.rope import RotaryPosition +from eole.constants import LayerNorm class TransformerDecoderLayer(TransformerDecoderLayerBase): @@ -81,15 +80,15 @@ def _forward( * attns ``(batch_size, head, T, src_len)`` """ - dec_mask = None + attn_mask = None src_pad_mask = src_pad_mask.unsqueeze(1) # [B,1,1,slen] if layer_in.size(1) > 1: # masking is necessary when sequence length is greater than one - dec_mask = self._compute_dec_mask(tgt_pad_mask, future) - dec_mask = dec_mask.unsqueeze(1) - dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1) - src_pad_mask = src_pad_mask.expand(-1, -1, dec_mask.size(3), -1) + attn_mask = self._compute_attn_mask(tgt_pad_mask, future) + attn_mask = attn_mask.unsqueeze(1) + attn_mask = attn_mask.expand(-1, -1, attn_mask.size(3), -1) + src_pad_mask = src_pad_mask.expand(-1, -1, attn_mask.size(3), -1) # mask now are (batch x 1 x tlen x s or t len) # 1 = heads to be expanded in MHA @@ -97,8 +96,7 @@ def _forward( self_attn, _ = self.self_attn( norm_layer_in, - mask=dec_mask, - sliding_window=self.sliding_window, + attn_mask=attn_mask, step=step, return_attn=return_attn, position_embeddings=position_embeddings, @@ -112,7 +110,7 @@ def _forward( enc_out, enc_out, norm_layer_in, - mask=src_pad_mask, + attn_mask=~src_pad_mask, return_attn=return_attn, ) if not self.shared_layer_norm: @@ -123,7 +121,11 @@ def _forward( else: norm_query = self.precontext_layernorm(self_attn + layer_in) ctx_attn, attns = self.context_attn( - enc_out, enc_out, norm_query, mask=src_pad_mask, return_attn=return_attn + enc_out, + enc_out, + norm_query, + attn_mask=~src_pad_mask, + return_attn=return_attn, ) if self.dropout_p > 0: ctx_attn = self.dropout(ctx_attn) @@ -153,9 +155,6 @@ def __init__( model_config, running_config=running_config ) - if model_config.position_encoding_type == PositionEncodingType.Rotary: - self.rope = RotaryPosition(model_config) - self.transformer_layers = nn.ModuleList( [ TransformerDecoderLayer( @@ -169,6 +168,7 @@ def __init__( self.layer_norm = LayerNorm[model_config.layer_norm]( model_config.hidden_size, eps=model_config.norm_eps ) + self._disable_cache() def forward(self, emb, **kwargs): """Decode, possibly stepwise.""" @@ -180,42 +180,19 @@ def forward(self, emb, **kwargs): assert tgt_pad_mask is not None, "TransformerDecoder requires a tgt pad mask" src_pad_mask = kwargs.pop("src_pad_mask", None) assert src_pad_mask is not None, "TransformerDecoder requires a src 
pad mask" - step = kwargs.pop("step", None) - - if enc_out is None: - enc_out = emb - if step == 0: - self._init_cache(enc_out.device) - elif step is None: - for layer in self.transformer_layers: - layer.self_attn.layer_cache = ( - False, - {"keys": torch.tensor([]), "values": torch.tensor([])}, - ) - layer.context_attn.layer_cache = ( - False, - {"keys": torch.tensor([]), "values": torch.tensor([])}, - ) - - if hasattr(self, "rope"): - position_embeddings = self.rope( - emb, - step=step, - device=emb.device, - ) - else: - position_embeddings = None - with_align = kwargs.pop("with_align", False) return_attn = with_align or kwargs.pop("return_attn", False) - + position_embeddings = kwargs.pop("position_embeddings", None) attn_aligns = [] + if step == 0: + self._enable_cache(enc_out.device) + for layer in self.transformer_layers: emb, attn, attn_align = layer( emb, - enc_out, + enc_out if enc_out is not None else emb, src_pad_mask, tgt_pad_mask, step=step, @@ -236,8 +213,7 @@ def forward(self, emb, **kwargs): # TODO change the way attns is returned dict => list or tuple (onnx) return emb, attns - def _init_cache(self, device): - + def _enable_cache(self, device): for layer in self.transformer_layers: # first value set to True triggered by the beginning of decoding # layer_cache becomes active in the MultiHeadedAttention fwd @@ -255,7 +231,14 @@ def _init_cache(self, device): "values": torch.tensor([], device=device), }, ) - if hasattr(layer.self_attn, "rope"): - layer.self_attn.rope = layer.self_attn.rope.to(device) - layer.self_attn.cos = layer.self_attn.cos.to(device) - layer.self_attn.sin = layer.self_attn.sin.to(device) + + def _disable_cache(self): + for layer in self.transformer_layers: + layer.self_attn.layer_cache = ( + False, + {"keys": torch.tensor([]), "values": torch.tensor([])}, + ) + layer.context_attn.layer_cache = ( + False, + {"keys": torch.tensor([]), "values": torch.tensor([])}, + ) diff --git a/eole/decoders/transformer_lm_decoder.py b/eole/decoders/transformer_lm_decoder.py index 40a66da4..2c72525b 100644 --- a/eole/decoders/transformer_lm_decoder.py +++ b/eole/decoders/transformer_lm_decoder.py @@ -10,8 +10,7 @@ TransformerDecoderLayerBase, TransformerDecoderBase, ) -from eole.constants import LayerNorm, PositionEncodingType -from eole.modules.rope import RotaryPosition +from eole.constants import LayerNorm class TransformerLMDecoderLayer(TransformerDecoderLayerBase): @@ -49,15 +48,15 @@ def _forward( * attns ``(batch_size, head, T, T)`` """ - dec_mask = None + attn_mask = None if layer_in.size(1) > 1: # Masking is necessary when sequence length is greater than one # The decoding has not started yet, # we compute the scores on the source tokens in one shot. 
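Note on the renamed helper: the convention standardized here (per the commit notes) is that pad_mask is True where a token is padding and must not be attended, while attn_mask is True where attention is allowed. A minimal standalone illustration, with made-up values that are not part of this diff:

import torch

tgt_pad_mask = torch.tensor([[[False, False, True]]])  # (batch, 1, tgt_len), True = padding
tgt_len = tgt_pad_mask.size(-1)
# lower triangle is True where causal attention is allowed
future_mask = torch.tril(torch.ones((tgt_len, tgt_len), dtype=torch.bool))
# keep a position only if it is not padding AND not in the future
attn_mask = ~tgt_pad_mask & future_mask.unsqueeze(0)  # (batch, tgt_len, tgt_len)
print(attn_mask.int())

The attention modules receive this mask expanded to (batch, 1, query_len, key_len), with the singleton head dimension broadcast inside MHA.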
- dec_mask = self._compute_dec_mask(pad_mask, future) - dec_mask = dec_mask.unsqueeze(1) - dec_mask = dec_mask.expand(-1, -1, dec_mask.size(3), -1) + attn_mask = self._compute_attn_mask(pad_mask, future) + attn_mask = attn_mask.unsqueeze(1) + attn_mask = attn_mask.expand(-1, -1, attn_mask.size(3), -1) # mask now are (batch x 1 x tlen x tlen) # 1 = heads to be expanded in MHA @@ -65,8 +64,7 @@ def _forward( attn_output, attns = self.self_attn( norm_layer_in, - mask=dec_mask, - sliding_window=self.sliding_window, + attn_mask=attn_mask, step=step, return_attn=return_attn, position_embeddings=position_embeddings, @@ -102,9 +100,6 @@ def __init__( ): super(TransformerLMDecoder, self).__init__(model_config) - if model_config.position_encoding_type == PositionEncodingType.Rotary: - self.rope = RotaryPosition(model_config) - self.transformer_layers = nn.ModuleList( [ TransformerLMDecoderLayer( @@ -118,6 +113,7 @@ def __init__( self.layer_norm = LayerNorm[model_config.layer_norm]( model_config.hidden_size, eps=model_config.norm_eps ) + self._disable_cache() def forward(self, emb, **kwargs): """Decode, possibly stepwise.""" @@ -126,36 +122,14 @@ def forward(self, emb, **kwargs): pad_mask = kwargs.pop("tgt_pad_mask", None) assert pad_mask is not None, "TransformerLMDecoder requires a pad mask" step = kwargs.pop("step", None) - - if hasattr(self, "rope"): - position_embeddings = self.rope( - emb, - step=step, - device=emb.device, - ) - else: - position_embeddings = None - - if step == 0: - # decoding mode. - # Initialize KV and key_pad_mask cache. - self._init_cache(emb.device, pad_mask) - elif step is None: - # training mode. - for layer in self.transformer_layers: - layer.self_attn.layer_cache = ( - False, - { - "keys": torch.tensor([]), - "values": torch.tensor([]), - "key_pad_mask": None, - }, - ) - with_align = kwargs.pop("with_align", False) return_attn = kwargs.pop("return_attn", False) return_attn = with_align or return_attn assert not with_align, "TransformerLMDecoder does not support align" + position_embeddings = kwargs.pop("position_embeddings", None) + + if step == 0: + self._enable_cache(emb.device, pad_mask) for layer in self.transformer_layers: emb, attn, _ = layer( @@ -174,7 +148,7 @@ def forward(self, emb, **kwargs): # TODO change the way attns is returned dict => list or tuple (onnx) return emb, attns - def _init_cache(self, device, mask): + def _enable_cache(self, device, pad_mask): for layer in self.transformer_layers: if hasattr(layer, "self_attn"): layer.self_attn.layer_cache = ( @@ -182,6 +156,18 @@ def _init_cache(self, device, mask): { "keys": torch.tensor([], device=device), "values": torch.tensor([], device=device), - "key_pad_mask": mask, + "key_pad_mask": pad_mask, + }, + ) + + def _disable_cache(self): + for layer in self.transformer_layers: + if hasattr(layer, "self_attn"): + layer.self_attn.layer_cache = ( + False, + { + "keys": torch.tensor([]), + "values": torch.tensor([]), + "key_pad_mask": None, }, ) diff --git a/eole/encoders/cnn_encoder.py b/eole/encoders/cnn_encoder.py index 0693c9ae..d905fcc1 100644 --- a/eole/encoders/cnn_encoder.py +++ b/eole/encoders/cnn_encoder.py @@ -37,7 +37,7 @@ def from_config(cls, model_config, running_config=None): running_config, # might be better to set this out of the model building logic (update_dropout call sometime in training) # noqa: E501 ) - def forward(self, emb, mask=None): + def forward(self, emb, **kwargs): """See :func:`EncoderBase.forward()`""" # batch x len x dim diff --git a/eole/encoders/encoder.py 
b/eole/encoders/encoder.py index 1e25b09a..dc48f5ac 100644 --- a/eole/encoders/encoder.py +++ b/eole/encoders/encoder.py @@ -16,13 +16,14 @@ class EncoderBase(nn.Module): def from_config(cls, model_config, running_config=None): raise NotImplementedError - def forward(self, emb, mask=None): + def forward(self, emb, **kwargs): """ Args: emb (FloatTensor): embeddings ``(batch, src_len, dim)`` - mask (BoolTensor): - mask ``(batch, maxlen)`` False when value, True when pad + **kwargs + pad_mask (BoolTensor): + pad_mask ``(batch, maxlen)`` False when value, True when pad Returns: (FloatTensor, FloatTensor, FloatTensor): diff --git a/eole/encoders/mean_encoder.py b/eole/encoders/mean_encoder.py index f927d71b..d1c4c1f7 100644 --- a/eole/encoders/mean_encoder.py +++ b/eole/encoders/mean_encoder.py @@ -23,15 +23,15 @@ def from_config(cls, model_config, running_config=None): # config = opt.model.encoder # MeanEncoderConfig return cls(model_config) - def forward(self, emb, mask=None): + def forward(self, emb, **kwargs): """See :func:`EncoderBase.forward()`""" - + pad_mask = kwargs.pop("pad_mask", None) batch, _, emb_dim = emb.size() - if mask is not None: + if pad_mask is not None: # we avoid padding while mean pooling - mask = (~mask).float() - mean = torch.bmm(mask.unsqueeze(1), emb).mean(1) + pad_mask = (~pad_mask).float() + mean = torch.bmm(pad_mask, emb).mean(1) else: mean = emb.mean(1) diff --git a/eole/encoders/rnn_encoder.py b/eole/encoders/rnn_encoder.py index 6a4b21e7..0ae2a86b 100644 --- a/eole/encoders/rnn_encoder.py +++ b/eole/encoders/rnn_encoder.py @@ -43,7 +43,7 @@ def from_config(cls, model_config, running_config=None): """Alternate constructor.""" return cls(model_config, running_config=running_config) - def forward(self, emb, mask=None): + def forward(self, emb, **kwargs): """See :func:`EncoderBase.forward()`""" enc_out, enc_final_hs = self.rnn(emb) diff --git a/eole/encoders/transformer.py b/eole/encoders/transformer.py index cb7f729c..95352557 100644 --- a/eole/encoders/transformer.py +++ b/eole/encoders/transformer.py @@ -7,8 +7,7 @@ from eole.encoders.encoder import EncoderBase from eole.modules.multi_headed_attn import SelfMHA from eole.modules.transformer_mlp import MLP -from eole.constants import LayerNorm, PositionEncodingType -from eole.modules.rope import RotaryPosition +from eole.constants import LayerNorm class TransformerEncoderLayer(nn.Module): @@ -46,11 +45,11 @@ def __init__( running_config=running_config, ) - def forward(self, layer_in, mask, position_embeddings=None): + def forward(self, layer_in, pad_mask, position_embeddings=None): """ Args: layer_in (FloatTensor): ``(batch_size, src_len, model_dim)`` - mask (LongTensor): ``(batch_size, 1, src_len)`` + pad_mask (LongTensor): ``(batch_size, 1, src_len)`` position_embeddings (FloatTensor): rotary position encodings, if any Returns: @@ -59,7 +58,7 @@ def forward(self, layer_in, mask, position_embeddings=None): """ norm_layer_in = self.input_layernorm(layer_in) context, _ = self.self_attn( - norm_layer_in, mask=mask, position_embeddings=position_embeddings + norm_layer_in, attn_mask=~pad_mask, position_embeddings=position_embeddings ) if self.dropout_p > 0: context = self.dropout(context) @@ -102,9 +101,6 @@ def __init__( ): super(TransformerEncoder, self).__init__() - if model_config.position_encoding_type == PositionEncodingType.Rotary: - self.rope = RotaryPosition(model_config) - self.transformer_layers = nn.ModuleList( [ TransformerEncoderLayer( @@ -127,23 +123,20 @@ def from_config(cls, model_config, 
running_config=None): running_config, ) - def forward(self, emb, mask=None): + def forward(self, emb, **kwargs): """See :func:`EncoderBase.forward()`""" + pad_mask = kwargs.pop("pad_mask", None) + assert pad_mask is not None, "TransformerEncoder requires a src pad mask" + position_embeddings = kwargs.pop("position_embeddings", None) enc_out = emb - mask = mask.unsqueeze(1).unsqueeze(1) - # mask is now (batch x 1 x 1 x maxlen) - mask = mask.expand(-1, -1, mask.size(3), -1) - # Padding mask is now (batch x 1 x maxlen x maxlen) + pad_mask = pad_mask.unsqueeze(1) # batch x 1 x 1 x maxlen + pad_mask = pad_mask.expand( + -1, -1, pad_mask.size(3), -1 + ) # batch x 1 x maxlen x maxlen # 1 to be expanded to number of heads in MHA - # Run the forward pass of every layer of the tranformer. - - if hasattr(self, "rope"): - position_embeddings = self.rope(emb, step=0, device=emb.device) - else: - position_embeddings = None for layer in self.transformer_layers: - enc_out = layer(enc_out, mask, position_embeddings=position_embeddings) + enc_out = layer(enc_out, pad_mask, position_embeddings=position_embeddings) enc_out = self.layer_norm(enc_out) return enc_out, None diff --git a/eole/inputters/text_utils.py b/eole/inputters/text_utils.py index 37c85277..e27a30bd 100644 --- a/eole/inputters/text_utils.py +++ b/eole/inputters/text_utils.py @@ -1,4 +1,7 @@ import torch + +# import torch.nn.functional as F +# import math from eole.constants import DefaultTokens, CorpusTask, ModelType from torch.nn.utils.rnn import pad_sequence from eole.utils.logging import logger @@ -180,7 +183,14 @@ def tensorify(vocabs, minibatch, device, left_pad=False): ] padidx = vocabs["src"][vocabs["specials"].get("pad_token", DefaultTokens.PAD)] tbatchsrc = pad_sequence(tbatchsrc, batch_first=True, padding_value=padidx) - + """ + This removes some recompiles in torch.dynamo, but slows down and make inference tricky + tbatchsrc = F.pad( + tbatchsrc, + (0, max(0, math.ceil(tbatchsrc.size(1) / 8) * 8 - tbatchsrc.size(1))), + value=padidx, + ) + """ if left_pad: tensor_batch["src"] = tbatchsrc.flip(dims=[1]) else: diff --git a/eole/models/model.py b/eole/models/model.py index 7609307f..5e2dfc6f 100644 --- a/eole/models/model.py +++ b/eole/models/model.py @@ -14,19 +14,33 @@ ) from torch.nn.utils import skip_init from torch.nn.init import xavier_uniform_, zeros_, uniform_, normal_ -from eole.utils.misc import use_gpu, sequence_mask, get_device +from eole.utils.misc import use_gpu, get_device from eole.inputters.inputter import dict_to_vocabs # copied from model_builder to facilitate tests, but should not live there in the end from eole.encoders import str2enc from eole.decoders import str2dec -from eole.constants import DefaultTokens +from eole.constants import DefaultTokens, PositionEncodingType from eole.modules.embeddings import Embeddings - +from eole.modules.rope import RotaryPosition from eole.models.model_saver import load_checkpoint from eole.modules.estimator import FeedForward +class NoOpPosition: + """A no-op position encoding callable.""" + + def update(self, *args, **kwargs): + return None + + +def build_rope(model_config): + if model_config.embeddings.position_encoding_type == PositionEncodingType.Rotary: + return RotaryPosition(model_config) + else: + return NoOpPosition() + + def build_encoder(model_config, running_config=None): """ Various encoder dispatcher function. 
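To make the build-time rope wiring concrete: the model now owns a single rope object whose update() returns either a (cos, sin) pair or None, so callers can pass position_embeddings down to encoder/decoder layers unconditionally. A condensed, self-contained sketch of that dispatch (DummyRope, build_rope_like, and the string arguments are illustrative stand-ins, not the real config objects):

import torch

class NoOpPosition:
    # returned when the model does not use rotary embeddings
    def update(self, *args, **kwargs):
        return None  # layers then simply receive position_embeddings=None

class DummyRope:
    # illustrative stand-in for RotaryPosition: precomputed (cos, sin) tables
    def __init__(self, max_pos=1024, rotary_dim=8):
        self.cos = torch.ones(max_pos, rotary_dim)
        self.sin = torch.zeros(max_pos, rotary_dim)

    def update(self, maxseqlen, step=0):
        return self.cos[: step + maxseqlen], self.sin[: step + maxseqlen]

def build_rope_like(position_encoding_type):
    # condensed dispatch; the real build_rope inspects model_config.embeddings
    return DummyRope() if position_encoding_type == "Rotary" else NoOpPosition()

rope = build_rope_like("Rotary")
position_embeddings = rope.update(16, step=0)       # threaded down to every layer
print(position_embeddings[0].shape)                 # torch.Size([16, 8])
print(build_rope_like("Alibi").update(16, step=0))  # None

This is why the transformer decoders no longer instantiate RotaryPosition themselves and instead read position_embeddings from kwargs.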
@@ -113,6 +127,7 @@ def __init__(self, **kwargs): self.tgt_emb = kwargs.get("tgt_emb", None) self.add_estimator = kwargs.get("add_estimator", False) self.hidden_size = kwargs.get("hidden_size", None) + self.rope = kwargs.get("rope", None) self.share_decoder_embeddings = False if self.encoder is not None and self.src_emb is None: raise ValueError("An Encoder needs source Embeddings") @@ -352,12 +367,15 @@ def load_test_model(cls, config, device_id=0, model_path=None): checkpoint_model_config = checkpoint["config"].model # we actually need to merge inference opts and config here - config.model = checkpoint_model_config + update_dict = { + "model": checkpoint_model_config, + } # if not set in inference config, override with checkpoint (generalize to more fields?) if "quant_type" not in config.model_fields_set: - config.quant_type = checkpoint["config"].training.quant_type + update_dict["quant_type"] = checkpoint["config"].training.quant_type if "quant_layers" not in config.model_fields_set: - config.quant_layers = checkpoint["config"].training.quant_layers + update_dict["quant_layers"] = checkpoint["config"].training.quant_layers + config.update(**update_dict) vocabs = None # not super clean inheritance anti-pattern, # might be enhanced if checkpoint loading is split from model instanciation @@ -475,7 +493,7 @@ def from_config( return model, vocabs, model_config - def forward(self, src, tgt, src_len, bptt=False, with_align=False): + def forward(self, src, tgt, src_len, with_align=False): """Forward propagate a `src` and `tgt` pair for training. Args: @@ -486,8 +504,6 @@ def forward(self, src, tgt, src_len, bptt=False, with_align=False): tgt (LongTensor): A target sequence passed to decoder. Size ``(batch, tgt_len)``. src_len(LongTensor): The src lengths, pre-padding ``(batch,)``. - bptt (Boolean): A flag indicating if truncated bptt is set. - If bptt is false then init decoder state. with_align (Boolean): A flag indicating whether output alignment, Only valid for transformer decoder. @@ -671,7 +687,9 @@ def load_safe_state_dict( ) keyfound[name + "." + param_name] = True elif strict and ( - "lora" not in param_name and "slopes" not in param_name + "lora" not in param_name + and "slopes" not in param_name + and "rope" not in name ): # Let's warn instead of just passing logger.info( @@ -746,6 +764,8 @@ class EncoderDecoderModel(BaseModel): def __init__(self, **kwargs): super(EncoderDecoderModel, self).__init__(**kwargs) self.tgt_shift = 1 + self.src_pad_idx = self.src_emb.word_padding_idx + self.tgt_pad_idx = self.tgt_emb.word_padding_idx # we might want to disable this constructor some way if self.encoder is None or self.decoder is None: raise ValueError( @@ -773,10 +793,11 @@ def build_blocks(cls, model_config, vocabs, running_config=None): tgt_emb=tgt_emb, add_estimator=model_config.add_estimator, hidden_size=model_config.decoder.hidden_size, + rope=build_rope(model_config), ) # from there, the base blocks exist, and the rest is done in the from_opt from base class - def forward(self, src, tgt, src_len, bptt=False, with_align=False): + def forward(self, src, tgt, src_len, with_align=False): """An EncoderDecoderModel forward the src side to the encoder. Then the output of encoder ``enc_out`` is forwarded to the decoder along with the target excluding the last token. 
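For reference, the estimator input computed in the next hunk is a mean of decoder states restricted to non-padding positions. A standalone version of that pooling, with made-up shapes and pad index:

import torch

pad_idx = 0
dec_in = torch.tensor([[5, 7, 0, 0]])   # (batch, tgt_len), 0 = padding
dec_out = torch.randn(1, 4, 6)          # (batch, tgt_len, hidden)
keep = ~dec_in.eq(pad_idx)              # True on real tokens only
pooled = (dec_out * keep.unsqueeze(-1).float()).sum(dim=1) / keep.sum(
    dim=1, keepdim=True
).float()
print(pooled.shape)  # torch.Size([1, 6]), one summary vector per sequence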
@@ -784,16 +805,17 @@ def forward(self, src, tgt, src_len, bptt=False, with_align=False): * enc_final_hs in the case of RNNs * enc_out + enc_final_hs in the case of CNNs * src in the case of Transformer""" - mask = sequence_mask(src_len) - enc_out, enc_final_hs = self.encoder(self.src_emb(src), mask=mask) - if not bptt: - self.decoder.init_state(src=src, enc_out=enc_out, enc_final_hs=enc_final_hs) - - pad_idx = self.src_emb.word_padding_idx - src_pad_mask = src.eq(pad_idx).unsqueeze(1) # [B, 1, T_src] - tgt_pad_mask = tgt[:, :-1].eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt] - + src_pad_mask = src.eq(self.src_pad_idx).unsqueeze(1) # [B, 1, T_src] + position_embeddings = self.rope.update(src.size(1), step=0) + enc_out, enc_final_hs = self.encoder( + self.src_emb(src), + pad_mask=src_pad_mask, + position_embeddings=position_embeddings, + ) + self.decoder.init_state(src=src, enc_out=enc_out, enc_final_hs=enc_final_hs) dec_in = tgt[:, :-1] + tgt_pad_mask = dec_in.eq(self.tgt_pad_idx).unsqueeze(1) # [B, 1, T_tgt] + dec_out, attns = self.decoder( self.tgt_emb(dec_in), enc_out=enc_out, @@ -801,10 +823,11 @@ def forward(self, src, tgt, src_len, bptt=False, with_align=False): with_align=with_align, src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask, + position_embeddings=position_embeddings, ) if self.add_estimator: # we take the average of dec_out using the pad mask - pad_mask2 = ~dec_in.eq(pad_idx) + pad_mask2 = ~dec_in.eq(self.tgt_pad_idx) in_estim2 = (dec_out * pad_mask2.unsqueeze(-1).float()).sum( dim=1 ) / pad_mask2.sum(dim=1, keepdim=True).float() @@ -831,6 +854,7 @@ class DecoderModel(BaseModel): def __init__(self, **kwargs): super(DecoderModel, self).__init__(**kwargs) self.tgt_shift = 0 + self.pad_idx = self.tgt_emb.word_padding_idx if self.encoder is not None: raise ValueError("DecoderModel should not be used" "with an encoder") if self.decoder is None: @@ -847,28 +871,27 @@ def build_blocks(cls, model_config, vocabs, running_config=None): tgt_emb=tgt_emb, add_estimator=model_config.add_estimator, hidden_size=model_config.decoder.hidden_size, + rope=build_rope(model_config), ) # from there, the base blocks exist, and the rest is done in the from_opt from base class - def forward(self, src, tgt, src_len, bptt=False, with_align=False): + def forward(self, src, _, src_len, with_align=False): """A DecoderModel forward the src side to the decoder along with the source lengths vector. 
It is a decoder only LM (cf GPT-2)""" - if not bptt: - self.decoder.init_state() - emb = self.tgt_emb(src) - pad_idx = self.tgt_emb.word_padding_idx - pad_mask = src.eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt] + self.decoder.init_state() + position_embeddings = self.rope.update(src.size(1), step=0) dec_out, attns = self.decoder( - emb, + self.tgt_emb(src), enc_out=None, src_len=src_len, with_align=with_align, - tgt_pad_mask=pad_mask, + tgt_pad_mask=src.eq(self.pad_idx).unsqueeze(1), + position_embeddings=position_embeddings, ) if self.add_estimator: # we take the average of dec_out using the pad mask - pad_mask2 = ~src.eq(pad_idx) + pad_mask2 = ~src.eq(self.pad_idx) in_estim2 = (dec_out * pad_mask2.unsqueeze(-1).float()).sum( dim=1 ) / pad_mask2.sum(dim=1, keepdim=True).float() @@ -893,6 +916,7 @@ class EncoderModel(BaseModel): def __init__(self, **kwargs): super(EncoderModel, self).__init__(**kwargs) self.tgt_shift = 1 + self.pad_idx = self.src_emb.word_padding_idx if self.decoder is not None: raise ValueError("EncoderModel should not be used" "with a decoder") if self.encoder is None: @@ -909,19 +933,24 @@ def build_blocks(cls, model_config, vocabs, running_config=None): src_emb=src_emb, add_estimator=model_config.add_estimator, hidden_size=model_config.encoder.hidden_size, + rope=build_rope(model_config), ) # from there, the base blocks exist, and the rest is done in the from_opt from base class - def forward(self, src, tgt, src_len, bptt=False, with_align=False): + def forward(self, src, _, src_len, with_align=False): """An EncoderModel encodes the source sentence to build hidden states""" - mask = sequence_mask(src_len) - enc_out, enc_final_hs = self.encoder(self.src_emb(src), mask=mask) + pad_mask = src.eq(self.pad_idx).unsqueeze(1) # [B, 1, T_src] + position_embeddings = self.rope.update(src.size(1), step=0) + enc_out, enc_final_hs = self.encoder( + self.src_emb(src), + pad_mask=pad_mask, + position_embeddings=position_embeddings, + ) if self.add_estimator: # Version with average """ - pad_idx = self.tgt_emb.word_padding_idx - pad_mask1 = ~src.eq(pad_idx) + pad_mask1 = ~src.eq(self.pad_idx) in_estim1 = (enc_out * pad_mask1.unsqueeze(-1).float()).sum( dim=1 ) / pad_mask1.sum(dim=1, keepdim=True).float() diff --git a/eole/modules/multi_headed_attn.py b/eole/modules/multi_headed_attn.py index ddd9d043..313b78bb 100644 --- a/eole/modules/multi_headed_attn.py +++ b/eole/modules/multi_headed_attn.py @@ -5,7 +5,6 @@ from torch import Tensor from typing import Optional, Tuple from torch.nn.functional import scaled_dot_product_attention -from torch.nn.attention import SDPBackend, sdpa_kernel from torch.utils.checkpoint import checkpoint from torch.nn.utils import skip_init from .alibi_position_bias import AlibiPositionalBias @@ -221,7 +220,7 @@ def unshape(x: Tensor) -> Tensor: -> [batchsize x length x modeldim] """ x_0, x_1, _, x_3 = x.size() - return x.transpose(1, 2).contiguous().view(x_0, -1, x_1 * x_3) + return x.transpose(1, 2).reshape(x_0, -1, x_1 * x_3) class MultiHeadedAttention(torch.nn.Module): @@ -316,15 +315,20 @@ def __init__( self.relative_positions_buckets = model_config.relative_positions_buckets self.layer_cache = ( False, - {"keys": torch.tensor([]), "values": torch.tensor([])}, + { + "keys": torch.tensor([]), + "values": torch.tensor([]), + "key_pad_mask": None, + }, ) + self.sliding_window = model_config.sliding_window # TODO find a cleaner way to initialize? 
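On the KV cache: per the commit notes, the plain PyTorch decoding path now preallocates the cache in chunks (32 positions in this patch) and writes each decoded step into its slot, instead of concatenating one position per step. A simplified standalone sketch of that growth pattern, ignoring the sliding-window trimming (names and sizes are illustrative):

import torch

CHUNK = 32
keys = torch.zeros(2, 4, 0, 8)  # (batch, heads, cached_len, dim_per_head)

def expand_cache(cache, step, chunk=CHUNK):
    b, h, length, d = cache.shape
    if step >= length:  # no free slot left: append a zero-filled chunk
        cache = torch.cat((cache, torch.zeros(b, h, chunk, d)), dim=2)
    return cache

for step in range(3):
    new_key = torch.randn(2, 4, 1, 8)           # one decoded position per step
    keys = expand_cache(keys, step)
    keys[:, :, step, :] = new_key[:, :, 0, :]   # write into the preallocated slot
    attended = keys[:, :, : step + 1, :]        # only the filled part is used
print(keys.shape, attended.shape)  # (2, 4, 32, 8) and (2, 4, 3, 8)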
self.relative_positions_embeddings = None self.relative_attention_bias = None self.rotary_interleave = None - if model_config.relative_positions_buckets > 0: + if self.relative_positions_buckets > 0: self.relative_attention_bias = nn.Embedding( - model_config.relative_positions_buckets, self.heads + self.relative_positions_buckets, self.heads ) self.relative_positions_embeddings = None elif self.position_encoding_type == PositionEncodingType.Relative: @@ -333,7 +337,7 @@ def __init__( # relative_key / relative_value or only # relative_key. We implemented the same embed # for both. - vocab_size = model_config.n_positions * 2 + 1 + vocab_size = self.n_positions * 2 + 1 self.relative_positions_embeddings = nn.Embedding( vocab_size, self.dim_per_head ) @@ -343,7 +347,7 @@ def __init__( else: self.rotary_dim = model_config.rope_config.rotary_dim self.rotary_interleave = model_config.rope_config.rotary_interleave - elif model_config.position_encoding_type == PositionEncodingType.Alibi: + elif self.position_encoding_type == PositionEncodingType.Alibi: self.alibi = AlibiPositionalBias(self.heads) self.maybe_ckpt = ( @@ -398,10 +402,9 @@ def _prepare_inputs( query = shape(query, self.dim_per_head) if self.position_encoding_type == PositionEncodingType.Rotary: - start_pos = 0 seqlen = query.size(2) - cos = position_embeddings[0][start_pos : start_pos + seqlen] - sin = position_embeddings[1][start_pos : start_pos + seqlen] + cos = position_embeddings[0][:seqlen] + sin = position_embeddings[1][:seqlen] query, key = apply_rotary_emb( query, key, (cos, sin), interleave=self.rotary_interleave ) @@ -412,9 +415,7 @@ def _compute_attention( key: Tensor, value: Tensor, query: Tensor, - mask: Optional[Tensor] = None, - attn_type: Optional[str] = "self", - sliding_window: Optional[int] = 0, + attn_mask: Optional[Tensor] = None, return_attn: Optional[bool] = False, ) -> Tuple[Tensor, Tensor]: """ @@ -427,8 +428,8 @@ def _compute_attention( value vectors ``(batch, head, key_len, dim)`` query (Tensor): set of `query_len` query vectors ``(batch, head, query_len, dim)`` - mask: binary mask 1/0 indicating which keys have - zero / non-zero attention ``(batch, 1, query_len, key_len)`` + attn_mask (bool Tensor): True = position needs attention + ``(batch, 1, query_len, key_len)`` Returns: (Tensor, Tensor): @@ -455,15 +456,14 @@ def _compute_attention( and query.device.type != "cpu" ): # Apply pytorch scaled_dot_product_attention. - with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]): - attn_output = scaled_dot_product_attention( - query, - key, - value, - ~mask if mask is not None else None, - self.dropout_p, - is_causal=False, - ) + attn_output = scaled_dot_product_attention( + query, + key, + value, + attn_mask if attn_mask is not None else None, + self.dropout_p, + is_causal=False, + ) attn = None else: query /= sqrt(self.dim_per_head) @@ -509,11 +509,13 @@ def _compute_attention( scores = scores.float() - if mask is not None: + if attn_mask is not None: # not 100% necessary but expand to nb of heads - mask = mask.expand(-1, self.heads // self.parallel_gpu, -1, -1) + attn_mask = attn_mask.expand( + -1, self.heads // self.parallel_gpu, -1, -1 + ) # now mask and scores have the same shape - scores = scores.masked_fill(mask, -1e18) + scores = scores.masked_fill(~attn_mask, -1e18) # 3) Apply attention dropout and compute context vectors. 
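With the True-means-attend convention, the boolean mask can be passed straight to torch's scaled_dot_product_attention, which treats True as 'this position participates in attention', so the earlier mask inversion and sdpa_kernel context manager are no longer needed. A standalone call with toy shapes:

import torch
from torch.nn.functional import scaled_dot_product_attention

q = torch.randn(1, 2, 4, 8)  # (batch, heads, seq, dim_per_head)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
# causal boolean mask, broadcast over the head dimension
attn_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool)).view(1, 1, 4, 4)
out = scaled_dot_product_attention(q, k, v, attn_mask, 0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 2, 4, 8])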
attn = self.softmax(scores).to(query.dtype) @@ -553,48 +555,73 @@ def __init__( self.n_positions = model_config.n_positions super(SelfMHA, self).__init__(model_config, running_config, is_decoder) + def _expand_cache(self, add_length, step): + b, h, l, dph = self.layer_cache[1]["keys"].shape + if step >= l: + ktype = self.layer_cache[1]["keys"].dtype + kdev = self.layer_cache[1]["keys"].device + self.layer_cache[1]["keys"] = torch.cat( + ( + self.layer_cache[1]["keys"], + torch.zeros((b, h, add_length, dph), device=kdev, dtype=ktype), + ), + dim=2, + ) + self.layer_cache[1]["values"] = torch.cat( + ( + self.layer_cache[1]["values"], + torch.zeros((b, h, add_length, dph), device=kdev, dtype=ktype), + ), + dim=2, + ) + if self.sliding_window > 0 and l > self.sliding_window: + self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][:, :, 1:, :] + self.layer_cache[1]["values"] = self.layer_cache[1]["values"][:, :, 1:, :] + return self.sliding_window + else: + return step + def _prepare_inputs_w_cache( self, query, key, value, step: Optional[int] = 0, - sliding_window: Optional[int] = 0, position_embeddings=None, ): - start_pos = step - seqlen = query.size(2) if self.position_encoding_type == PositionEncodingType.Rotary: - cos = position_embeddings[0][start_pos : start_pos + seqlen] - sin = position_embeddings[1][start_pos : start_pos + seqlen] + seqlen = query.size(2) + cos = position_embeddings[0][step : step + seqlen] + sin = position_embeddings[1][step : step + seqlen] query, key = apply_rotary_emb( query, key, (cos, sin), interleave=self.rotary_interleave ) - # update the cache - if self.layer_cache[1]["keys"].numel() != 0: - key = torch.cat((self.layer_cache[1]["keys"], key), dim=2) - value = torch.cat((self.layer_cache[1]["values"], value), dim=2) - if sliding_window > 0 and key.size(2) > sliding_window: - key = key[:, :, 1:, :] - value = value[:, :, 1:, :] - # mask keys for LM left padding by batch + if step == 0: + # mask keys for LM left padding by batch key_pad_mask = self.layer_cache[1].get("key_pad_mask", None) if key_pad_mask is not None: x = key_pad_mask.expand(-1, key.size(1), -1) x = x.unsqueeze(3).expand(-1, -1, -1, key.size(3)) key = key.masked_fill(x, 0) - - self.layer_cache[1]["keys"] = key - self.layer_cache[1]["values"] = value - - return key, value, query + # init cache with initial key, value + self.layer_cache[1]["keys"] = key + self.layer_cache[1]["values"] = value + return key, value, query + else: + cache_len = self._expand_cache(32, step) + self.layer_cache[1]["keys"][:, :, cache_len, :] = key[:, :, 0, :] + self.layer_cache[1]["values"][:, :, cache_len, :] = value[:, :, 0, :] + return ( + self.layer_cache[1]["keys"][:, :, : cache_len + 1, :], + self.layer_cache[1]["values"][:, :, : cache_len + 1, :], + query, + ) def forward( self, query: Tensor, - mask: Optional[Tensor] = None, - sliding_window: Optional[int] = 0, + attn_mask: Optional[Tensor] = None, step: Optional[int] = 0, return_attn: Optional[bool] = False, position_embeddings=None, @@ -607,14 +634,11 @@ def forward( key = shape(key, self.dim_per_head) value = shape(value, self.dim_per_head) query = shape(query, self.dim_per_head) - start_pos = step if ( step == 0 or not self.flash or self.position_encoding_type in [PositionEncodingType.Relative, PositionEncodingType.Alibi] - or query.size(0) - > 128 # it seems for large batch size flash not optimum or query.dtype not in [torch.float16, torch.bfloat16] # to match with flash ): @@ -623,52 +647,15 @@ def forward( key, value, step=step, - 
sliding_window=sliding_window, position_embeddings=position_embeddings, ) else: # Fast path with flash_attn_with_kvcache - if start_pos >= self.layer_cache[1]["keys"].size(2): - self.layer_cache[1]["keys"] = torch.cat( - [ - self.layer_cache[1]["keys"], - torch.zeros( - self.layer_cache[1]["keys"].shape[:-2] - + (32,) - + self.layer_cache[1]["keys"].shape[-1:], - device=query.device, - dtype=query.dtype, - ), - ], - dim=-2, - ) - self.layer_cache[1]["values"] = torch.cat( - [ - self.layer_cache[1]["values"], - torch.zeros( - self.layer_cache[1]["values"].shape[:-2] - + (32,) - + self.layer_cache[1]["values"].shape[-1:], - device=query.device, - dtype=query.dtype, - ), - ], - dim=-2, - ) - if sliding_window > 0 and key.size(2) > sliding_window: - self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][ - :, :, 1:, : - ] - self.layer_cache[1]["values"] = self.layer_cache[1]["values"][ - :, :, 1:, : - ] + cache_len = self._expand_cache(32, step) if position_embeddings is not None: - cos = position_embeddings[0][:, : self.rotary_dim // 2].to( - query.dtype - ) - sin = position_embeddings[1][:, : self.rotary_dim // 2].to( - query.dtype - ) + rotdim = self.rotary_dim // 2 + cos = position_embeddings[0][:, :rotdim].to(query.dtype) + sin = position_embeddings[1][:, :rotdim].to(query.dtype) else: cos = None sin = None @@ -678,9 +665,9 @@ def forward( self.layer_cache[1]["values"].transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), - rotary_cos=cos, - rotary_sin=sin, - cache_seqlens=step, + rotary_cos=cos.contiguous(), + rotary_sin=sin.contiguous(), + cache_seqlens=cache_len, rotary_interleaved=self.rotary_interleave, ).transpose(1, 2) attn_output = self.final_linear(unshape(context)) @@ -702,9 +689,7 @@ def forward( key, value, query, - mask=mask, - attn_type="self", - sliding_window=sliding_window, + attn_mask=attn_mask, return_attn=return_attn, ) @@ -737,8 +722,7 @@ def forward( key: Tensor, value: Tensor, query: Tensor, - mask: Optional[Tensor] = None, - sliding_window: Optional[int] = 0, + attn_mask: Optional[Tensor] = None, step: Optional[int] = 0, return_attn: Optional[bool] = False, ) -> Tuple[Tensor, Tensor]: @@ -753,8 +737,6 @@ def forward( key, value, query, - mask=mask, - attn_type="context", - sliding_window=sliding_window, + attn_mask=attn_mask, return_attn=return_attn, ) diff --git a/eole/modules/rope.py b/eole/modules/rope.py index 8ba88ae7..cafa5554 100644 --- a/eole/modules/rope.py +++ b/eole/modules/rope.py @@ -42,14 +42,21 @@ def __init__(self, model_config): rotary_dim = model_config.rope_config.rotary_dim self.rotary_interleave = model_config.rope_config.rotary_interleave self.rotary_theta = model_config.rope_config.rotary_theta - self.inv_freq = 1.0 / ( + inv_freq = 1.0 / ( self.rotary_theta ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim) ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO: extend with other scaling types if getattr(self.model_config.rope_config, "scaling_type", None) == "llama3": self.llama3_scaling() - # cache rope tensor to limit unnecessary computations - self.rope = None + tmax = torch.arange(1024) + rope = torch.outer(tmax, inv_freq) + cos = torch.cos(rope) + sin = torch.sin(rope) + cos = torch.cat((cos, cos), dim=-1) # Double the size by repeating `cos` + sin = torch.cat((sin, sin), dim=-1) # Double the size by repeating `sin` + self.register_buffer("cos", cos, persistent=False) + self.register_buffer("sin", sin, persistent=False) def llama3_scaling(self): """ @@ -92,17 +99,15 @@ def llama3_scaling(self): ) * 
inv_freq_llama / factor + smooth_factor * inv_freq_llama is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) - self.inv_freq = inv_freq_llama + self.register_buffer("inv_freq", inv_freq_llama, persistent=False) - def forward(self, emb, step=0, device=None, prefetch=1024): + def update(self, maxseqlen, step=0, prefetch=1024): """ Computes the rotary position embeddings for a given input. Args: - emb: The input embeddings to which rotary embeddings will be applied. + maxseqlen: max seq length of the input embeddings. step: The current step or position within the sequence. Defaults to 0. - device: The device on which the computations should be performed. - If None, defaults to the device of `self.inv_freq`. offset: An optional offset to apply to the position indices. This is used for the specific `flash_attn_with_kvcache` path, which requires processes by chunks of 32 tokens. Defaults to 0. @@ -118,25 +123,24 @@ def forward(self, emb, step=0, device=None, prefetch=1024): """ if step is None: step = 0 - maxseqlen = emb.size(1) offset = ( 32 # make sure we have at least 32 positions for flash_attn_with_kvcache ) # This could probably a bit cleaner/homogenized with the offset case - if self.rope is not None: - if self.rope[0].size(0) >= max(offset + step, 0) + maxseqlen: - return self.rope - else: - maxseqlen = maxseqlen + prefetch + if self.cos.size(0) >= max(offset + step, 0) + maxseqlen: + return self.cos, self.sin else: - self.inv_freq = self.inv_freq.to(device) - tmax = torch.arange(max(offset + step, 0) + maxseqlen, device=device) + maxseqlen = maxseqlen + prefetch + + tmax = torch.arange( + max(offset + step, 0) + maxseqlen, device=self.inv_freq.device + ) rope = torch.outer(tmax, self.inv_freq) cos = torch.cos(rope) sin = torch.sin(rope) cos = torch.cat((cos, cos), dim=-1) # Double the size by repeating `cos` sin = torch.cat((sin, sin), dim=-1) # Double the size by repeating `sin` - # Cache the result for reuse - self.rope = (cos, sin) + self.register_buffer("cos", cos, persistent=False) + self.register_buffer("sin", sin, persistent=False) return cos, sin diff --git a/eole/predict/__init__.py b/eole/predict/__init__.py index 4887f5f3..847e5bf4 100644 --- a/eole/predict/__init__.py +++ b/eole/predict/__init__.py @@ -28,7 +28,7 @@ def build_predictor(config, device_id=0, report_score=True, logger=None): ) vocabs, model, model_config = load_test_model(config, device_id) - config.model = model_config + config.update(model=model_config) scorer = GNMTGlobalScorer.from_config(config) diff --git a/eole/predict/encoder.py b/eole/predict/encoder.py index 292245c4..d84d9c5b 100644 --- a/eole/predict/encoder.py +++ b/eole/predict/encoder.py @@ -3,7 +3,6 @@ from eole.constants import ModelType from eole.predict.greedy_search import GreedySearch from eole.predict.beam_search import BeamSearch -from eole.utils.misc import sequence_mask class Encoder(Inference): @@ -75,9 +74,9 @@ def _run_encoder(self, batch): src = batch["src"] src_len = batch["srclen"] batch_size = len(batch["srclen"]) - mask = sequence_mask(src_len) emb = self.model.src_emb(src) - enc_out, enc_final_hs = self.model.encoder(emb, mask) + pad_mask = src.eq(self._src_pad_idx).unsqueeze(1) + enc_out, enc_final_hs = self.model.encoder(emb, pad_mask=pad_mask) if src_len is None: assert not isinstance( @@ -142,16 +141,14 @@ def _predict_batch_with_strategy(self, batch, decode_strategy): return results - def _score_target(self, batch, 
enc_out, src_len, src_map): + def _score_target(self, batch, enc_out, src_len): tgt = batch["tgt"] tgt_in = tgt[:, :-1, :] log_probs, attn = self._decode_and_generate( tgt_in, enc_out, - batch, src_len=src_len, - src_map=src_map, ) log_probs[:, :, self._tgt_pad_idx] = 0 diff --git a/eole/predict/generator.py b/eole/predict/generator.py index 41887508..17df6c28 100644 --- a/eole/predict/generator.py +++ b/eole/predict/generator.py @@ -141,10 +141,8 @@ def _predict_batch_with_strategy(self, batch, decode_strategy, left_pad=True): log_probs, attn = self._decode_and_generate( decoder_input, None, - batch, src_len=decode_strategy.src_len, step=step if step == 0 else step + max(src_len.tolist()), - batch_offset=decode_strategy.batch_offset, ) if step == 0: @@ -180,11 +178,14 @@ def _predict_batch_with_strategy(self, batch, decode_strategy, left_pad=True): dec_in = torch.cat((src, dec_in), 1) tgt_pad_mask = dec_in.eq(self._tgt_pad_idx).unsqueeze(1) # [B, T_tgt] emb = self.model.tgt_emb(dec_in) + self.model.decoder._disable_cache() + position_embeddings = self.model.rope.update(dec_in.size(1), step=0) dec_out, _ = self.model.decoder( emb, enc_out=None, return_attn=False, tgt_pad_mask=tgt_pad_mask, + position_embeddings=position_embeddings, ) pad_mask = ~dec_in.eq(self._tgt_pad_idx) in_estim = (dec_out * pad_mask.unsqueeze(-1).float()).sum( diff --git a/eole/predict/inference.py b/eole/predict/inference.py index c977e426..f103fc6f 100644 --- a/eole/predict/inference.py +++ b/eole/predict/inference.py @@ -99,6 +99,9 @@ def __init__( self._tgt_pad_idx = vocabs["tgt"].lookup_token( vocabs.get("specials", {}).get("pad_token", DefaultTokens.PAD) ) + self._src_pad_idx = vocabs["src"].lookup_token( + vocabs.get("specials", {}).get("pad_token", DefaultTokens.PAD) + ) self._tgt_bos_idx = vocabs["tgt"].lookup_token( vocabs.get("specials", {}).get("bos_token", "") ) @@ -650,10 +653,8 @@ def _decode_and_generate( self, decoder_in, enc_out, - batch, src_len, step=None, - batch_offset=None, return_attn=False, ): @@ -661,6 +662,7 @@ def _decode_and_generate( # and [batch, src_len, hidden] as enc_out # in case of inference tgt_len = 1, batch = beam times batch_size # in case of Gold Scoring tgt_len = actual length, batch = 1 batch + # we still rely on src_len here because updated at each beam search step if isinstance(enc_out, tuple): src_max_len = enc_out[0].size(1) src_pad_mask = sequence_mask(src_len, src_max_len).unsqueeze( @@ -674,16 +676,16 @@ def _decode_and_generate( else: src_pad_mask = None tgt_pad_mask = decoder_in.eq(self._tgt_pad_idx).unsqueeze(1) # [B, 1, T_tgt] - - emb = self.model.tgt_emb(decoder_in, step=step) + position_embeddings = self.model.rope.update(decoder_in.size(1), step=step) dec_out, dec_attn = self.model.decoder( - emb, + self.model.tgt_emb(decoder_in, step=step), enc_out=enc_out, src_len=src_len, step=step, return_attn=self.global_scorer.has_cov_pen or return_attn, src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask, + position_embeddings=position_embeddings, ) # Generator forward. 
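
[Editorial sketch, not part of the patch] The `SelfMHA` hunk above replaces the per-step `torch.cat` on keys/values with a cache that is preallocated and grown in 32-position chunks (`_expand_cache`), and the predict path now fetches precomputed cos/sin tables via `rope.update` and passes them down as `position_embeddings`. A minimal standalone sketch of the cache-growth pattern, with hypothetical names and shapes rather than eole's actual API:

```python
import torch

CHUNK = 32  # grow the preallocated cache by this many positions at a time


def expand_cache(cache: dict, step: int, sliding_window: int = 0) -> int:
    """Grow a preallocated KV cache in CHUNK-sized blocks; return the index to write at."""
    b, h, length, dph = cache["keys"].shape
    if step >= length:
        pad = torch.zeros(b, h, CHUNK, dph, dtype=cache["keys"].dtype, device=cache["keys"].device)
        cache["keys"] = torch.cat((cache["keys"], pad), dim=2)
        cache["values"] = torch.cat((cache["values"], pad), dim=2)
    if sliding_window > 0 and length > sliding_window:
        # sliding window: drop the oldest position and keep writing at the window edge
        cache["keys"] = cache["keys"][:, :, 1:, :]
        cache["values"] = cache["values"][:, :, 1:, :]
        return sliding_window
    return step


# usage: write the new key/value at the returned index, then attend over [:, :, : idx + 1]
cache = {"keys": torch.zeros(2, 4, 8, 16), "values": torch.zeros(2, 4, 8, 16)}
idx = expand_cache(cache, step=8)                 # step 8 triggers growth from 8 to 8 + 32 slots
cache["keys"][:, :, idx, :] = torch.randn(2, 4, 16)
cache["values"][:, :, idx, :] = torch.randn(2, 4, 16)
k_used = cache["keys"][:, :, : idx + 1, :]
```

Along the same lines, the `position_embeddings` consumed above are slices of cos/sin tables built once up front; a sketch of that construction, again with made-up helper names:

```python
def build_rope_tables(rotary_dim: int, max_len: int, theta: float = 10000.0):
    """Precompute cos/sin tables once; decode steps only slice them."""
    inv_freq = 1.0 / (theta ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    angles = torch.outer(torch.arange(max_len).float(), inv_freq)   # (max_len, rotary_dim // 2)
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)           # (max_len, rotary_dim)
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)
    return cos, sin


cos, sin = build_rope_tables(rotary_dim=64, max_len=1024)
step, seqlen = 10, 1                                                # decoding one token at position 10
cos_step, sin_step = cos[step : step + seqlen], sin[step : step + seqlen]
```
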
if "std" in dec_attn: diff --git a/eole/predict/translator.py b/eole/predict/translator.py index c2ddf6fb..81b116c0 100644 --- a/eole/predict/translator.py +++ b/eole/predict/translator.py @@ -4,7 +4,7 @@ from torch.nn.utils.rnn import pad_sequence from eole.predict.beam_search import BeamSearch from eole.predict.greedy_search import GreedySearch -from eole.utils.misc import tile, sequence_mask +from eole.utils.misc import tile from eole.utils.alignment import extract_alignment from eole.predict.inference import Inference @@ -59,18 +59,17 @@ def _align_forward(self, batch, predictions): # (4) reshape and apply pad masking in the target sequence tgt = batch_tgt_idxs.view(-1, batch_tgt_idxs.size(-1)) - src_pad_idx = self.model.src_emb.word_padding_idx - tgt_pad_idx = self.model.tgt_emb.word_padding_idx - src_pad_mask = src.eq(src_pad_idx).unsqueeze(1) - tgt_pad_mask = tgt[:, :-1].eq(tgt_pad_idx).unsqueeze(1) - + src_pad_mask = src.eq(self._src_pad_idx).unsqueeze(1) + tgt_pad_mask = tgt[:, :-1].eq(self._tgt_pad_idx).unsqueeze(1) dec_in = tgt[:, :-1] + position_embeddings = self.model.rope.update(dec_in.size(1), step=0) _, attns = self.model.decoder( self.model.tgt_emb(dec_in), enc_out=enc_out, src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask, with_align=True, + position_embeddings=position_embeddings, ) alignment_attn = attns["align"] # ``(B, tgt_len-1, src_len)`` @@ -141,9 +140,12 @@ def _run_encoder(self, batch): src = batch["src"] src_len = batch["srclen"] batch_size = len(batch["srclen"]) - mask = sequence_mask(src_len) emb = self.model.src_emb(src) - enc_out, enc_final_hs = self.model.encoder(emb, mask) + pad_mask = src.eq(self._src_pad_idx).unsqueeze(1) # [B, 1, T_src] + position_embeddings = self.model.rope.update(src.size(1), step=0) + enc_out, enc_final_hs = self.model.encoder( + emb, pad_mask=pad_mask, position_embeddings=position_embeddings + ) if src_len is None: assert not isinstance( @@ -212,10 +214,8 @@ def _translate_batch_with_strategy(self, batch, decode_strategy): log_probs, attn = self._decode_and_generate( decoder_input, enc_out, - batch, src_len=decode_strategy.src_len, step=step, - batch_offset=decode_strategy.batch_offset, return_attn=decode_strategy.return_attention, ) @@ -256,14 +256,13 @@ def _translate_batch_with_strategy(self, batch, decode_strategy): device=dec_in.device, ) dec_in = torch.cat((prepend_value, dec_in), dim=1) - src_max_len = src_len2.max() - src_pad_mask = sequence_mask(src_len2, src_max_len).unsqueeze( - 1 - ) # [B, 1, T_src] + src_pad_mask = src.eq(self._src_pad_idx).unsqueeze(1) # [B, 1, T_src] tgt_pad_mask = ( dec_in[:, :-1].eq(self._tgt_pad_idx).unsqueeze(1) ) # [B, 1, T_tgt] emb = self.model.tgt_emb(dec_in[:, :-1]) + self.model.decoder._disable_cache() + position_embeddings = self.model.rope.update(dec_in[:, :-1].size(1), step=0) dec_out, _ = self.model.decoder( emb, enc_out=enc_out2, @@ -272,6 +271,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy): return_attn=False, src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask, + position_embeddings=position_embeddings, ) pad_mask2 = ~dec_in[:, :-1].eq(self._tgt_pad_idx) in_estim2 = (dec_out * pad_mask2.unsqueeze(-1).float()).sum( @@ -304,7 +304,6 @@ def _score_target(self, batch, enc_out, src_len): log_probs, attn = self._decode_and_generate( tgt_in, enc_out, - batch, src_len=src_len, ) diff --git a/eole/tests/data/data_lm/gen-beam-sol.txt b/eole/tests/data/data_lm/gen-beam-sol.txt index e6a2656b..e4f93f80 100644 --- a/eole/tests/data/data_lm/gen-beam-sol.txt +++ 
b/eole/tests/data/data_lm/gen-beam-sol.txt @@ -1,5 +1,5 @@ you ! -in German German Presidency in German Presidency in German Presidency in German Presidency in the Netherlands . +in German Presidency in German Presidency in German Presidency in German Presidency in the Netherlands . the future . " ignored . diff --git a/eole/tests/data/data_lm/gen-sampling-beams-sol2.txt b/eole/tests/data/data_lm/gen-sampling-beams-sol2.txt index c361bcc4..36da5aa9 100644 --- a/eole/tests/data/data_lm/gen-sampling-beams-sol2.txt +++ b/eole/tests/data/data_lm/gen-sampling-beams-sol2.txt @@ -1,7 +1,7 @@ -you ! Next to your luck . -inspired by the absolut well-beeing-feeling of the Tauern Spa and ... -the top of the top of the top of the moment , health of the importance of capital , the world . -" Austrian " sold with " Delta , received from twenty years . -administered by the usefulness of the interinstitutional coherence of the recorded by the present in the young majority of renewable energies . -do so would like this subsidy is requested . -800 m2 can 't be seen on payments . +you ! It is your interpretation , or a number of number of balance between Russia ? +in German Presidency in German Presidency in German Presidency in the Netherlands , the Netherlands Presidency in German Presidency Presidency where the Treaty of continuing in all Member State . +a moment , one century . +" s leading directly on the region " s economy - region . +Israel; and be implemented . +do they do so have any ideas for this matter . +it is possible to make your reservation whatsoever . diff --git a/eole/tests/data/data_lm/gen-sampling-sol.txt b/eole/tests/data/data_lm/gen-sampling-sol.txt index 33a12ae0..c03dc62f 100644 --- a/eole/tests/data/data_lm/gen-sampling-sol.txt +++ b/eole/tests/data/data_lm/gen-sampling-sol.txt @@ -1,7 +1,7 @@ you ! -in German Presidency in German Presidency in the Netherlands Presidency in this: only 13 countries in the Netherlands , in this: only 13 countries in Germany " s Treaty of 25 century . +in German Presidency in German Presidency in the Netherlands Presidency in this: only 13 countries in the Treaty of 25 century . the top of the top of the top of the top of the top of the top of the top of the top of the top of the top of the top of the top of the top of the future of the future . " s famous Italian Prime Minister said John Leahy , Airbus Chief Operating Officer , Customers . fine words is administered by India , South Tibet as India , South Tibet as India , South Tibet as claimed by China in China in the United States . -do not know what is doing here . -the hotel is a few weeks ago , the first class is a big government . +do not know what is doing so that is not want to do not have any other society . +the hotel is a large scale factor in the country of the country of the country of women 's democracy . diff --git a/eole/tests/pull_request_check.sh b/eole/tests/pull_request_check.sh index 0a9697af..057992d0 100755 --- a/eole/tests/pull_request_check.sh +++ b/eole/tests/pull_request_check.sh @@ -4,7 +4,7 @@ # SKIP_DOWNLOADS=1 If files/uncompressed dirs exist don't download (if compressed files exist, just untar). # SKIP_FULL_CLEAN=1 Don't remove anything downloaded/uncompressed. -SKIP_FULL_CLEAN=1 +SKIP_FULL_CLEAN=0 PROJECT_ROOT=`dirname "$0"`"/../.." 
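
[Editorial sketch, not part of the patch] The predict hunks further up drop `sequence_mask(src_len)` in favor of pad masks built directly from token ids (`src.eq(pad_idx)`), where True marks a padded position. A small sketch, using a hypothetical `PAD_IDX`, showing that an id-based pad mask agrees with one derived from lengths for right-padded batches:

```python
import torch

PAD_IDX = 1  # hypothetical pad token id

src = torch.tensor([[5, 7, 9, PAD_IDX, PAD_IDX],
                    [4, 6, PAD_IDX, PAD_IDX, PAD_IDX]])
src_len = torch.tensor([3, 2])

# mask built directly from token ids: True = padded position
pad_mask = src.eq(PAD_IDX).unsqueeze(1)                                        # (B, 1, T)

# mask derived from lengths, valid when padding sits on the right
positions = torch.arange(src.size(1))
pad_mask_from_len = (positions.unsqueeze(0) >= src_len.unsqueeze(1)).unsqueeze(1)

assert torch.equal(pad_mask, pad_mask_from_len)
```
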
DATA_DIR="$PROJECT_ROOT/eole/tests/data" @@ -26,14 +26,13 @@ clean_up() # rm ${LOG_FILE} # fi if [[ "${SKIP_FULL_CLEAN}" == "1" ]]; then - # delete any .pt's that weren't downloaded - ls $TMP_OUT_DIR/*.pt | xargs -I {} rm -f $TMP_OUT_DIR/{} + # delete any .model's that weren't downloaded + ls $TMP_OUT_DIR/*.model | xargs -I {} rm -rf $TMP_OUT_DIR/{} else - # delete all .pt's - rm -r $TMP_OUT_DIR/dump_pred - rm -f $TMP_OUT_DIR/*.pt - rm -rf $TMP_OUT_DIR/sample - rm -d $TMP_OUT_DIR + rm -rf $TMP_OUT_DIR/dump_pred + rm -rf $TMP_OUT_DIR/*.model + rm -rf $TMP_OUT_DIR/eole.train.check + rm -f $TMP_OUT_DIR/eole.vocab.* fi } trap clean_up SIGINT SIGQUIT SIGKILL @@ -52,18 +51,17 @@ error_exit() # } # black check -# echo -n "[+] Doing Black check..." -# ${PYTHON} -m black --check . >> ${LOG_FILE} 2>&1 -# [ "$?" -eq 0 ] || error_exit -# echo "Succeeded" | tee -a ${LOG_FILE} +echo -n "[+] Doing Black check..." +${PYTHON} -m black --check . >> ${LOG_FILE} 2>&1 +[ "$?" -eq 0 ] || error_exit +echo "Succeeded" | tee -a ${LOG_FILE} # flake8 check -# echo -n "[+] Doing flake8 check..." -# ${PYTHON} -m flake8 --ignore *venv* >> ${LOG_FILE} 2>&1 -# [ "$?" -eq 0 ] || error_exit -# echo "Succeeded" | tee -a ${LOG_FILE} - -# exit +echo -n "[+] Doing flake8 check..." +#${PYTHON} -m flake8 --ignore *venv* . >> ${LOG_FILE} 2>&1 +${PYTHON} -m flake8 . >> ${LOG_FILE} 2>&1 +[ "$?" -eq 0 ] || error_exit +echo "Succeeded" | tee -a ${LOG_FILE} # unittest echo -n "[+] Doing unittest test..." @@ -90,7 +88,7 @@ rm -f -r $TMP_OUT_DIR/sample # # Training test # -echo -n "[+] Testing NMT vocab? /transforms prepare..." +echo -n "[+] Testing architecture rnn sample dump..." ${PYTHON} eole/bin/main.py train \ -config ${DATA_DIR}/data.yaml \ -model '{"architecture": "rnn"}' \ @@ -107,7 +105,7 @@ echo "Succeeded" | tee -a ${LOG_FILE} echo "[+] Doing Training test..." -echo -n " [+] Testing NMT training..." +echo -n " [+] Testing architecture rnn training..." ${PYTHON} eole/bin/main.py train \ -config ${DATA_DIR}/data.yaml \ -src_vocab $TMP_OUT_DIR/eole.vocab.src \ @@ -124,7 +122,7 @@ ${PYTHON} eole/tests/test_events.py --logdir $TMP_OUT_DIR/logs_train -tensorboar echo "Succeeded" | tee -a ${LOG_FILE} rm -r $TMP_OUT_DIR/logs_train -echo -n " [+] Testing NMT training and validation..." +echo -n " [+] Testing architecture rnn training and validation..." ${PYTHON} eole/bin/main.py train \ -config ${DATA_DIR}/data.yaml \ -src_vocab $TMP_OUT_DIR/eole.vocab.src \ @@ -143,35 +141,35 @@ ${PYTHON} eole/tests/test_events.py --logdir $TMP_OUT_DIR/logs_train_and_valid - echo "Succeeded" | tee -a ${LOG_FILE} rm -r $TMP_OUT_DIR/logs_train_and_valid -echo -n " [+] Testing NMT training w/ align..." +echo -n " [+] Testing architecture rnn training w/ coverage..." 
${PYTHON} eole/bin/main.py train \ - -config ${DATA_DIR}/align_data.yaml \ + -config ${DATA_DIR}/data.yaml \ -src_vocab $TMP_OUT_DIR/eole.vocab.src \ -tgt_vocab $TMP_OUT_DIR/eole.vocab.tgt \ -src_vocab_size 1000 \ -tgt_vocab_size 1000 \ - -model '{"layers": 4, "hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16, "position_encoding_type": None}, "encoder": {"encoder_type": "transformer", "heads": 2}, "decoder": {"decoder_type": "transformer", "lambda_align": 0.05, "alignment_layer": 2, "alignment_heads": 0, "heads": 2}}' \ + -model '{"architecture": "rnn", "hidden_size": 10, "embeddings": {"word_vec_size": 5, "position_encoding_type": None}, "decoder": {"coverage_attn": True, "lambda_coverage": 0.1}}' \ -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 10}' \ -report_every 5 \ >> ${LOG_FILE} 2>&1 [ "$?" -eq 0 ] || error_exit echo "Succeeded" | tee -a ${LOG_FILE} -echo -n " [+] Testing NMT training w/ coverage..." +echo -n " [+] Testing architecture custom transformer training w/ align..." ${PYTHON} eole/bin/main.py train \ - -config ${DATA_DIR}/data.yaml \ + -config ${DATA_DIR}/align_data.yaml \ -src_vocab $TMP_OUT_DIR/eole.vocab.src \ -tgt_vocab $TMP_OUT_DIR/eole.vocab.tgt \ -src_vocab_size 1000 \ -tgt_vocab_size 1000 \ - -model '{"architecture": "rnn", "hidden_size": 10, "embeddings": {"word_vec_size": 5, "position_encoding_type": None}, "decoder": {"coverage_attn": True, "lambda_coverage": 0.1}}' \ + -model '{"layers": 4, "hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16, "position_encoding_type": None}, "encoder": {"encoder_type": "transformer", "heads": 2}, "decoder": {"decoder_type": "transformer", "lambda_align": 0.05, "alignment_layer": 2, "alignment_heads": 0, "heads": 2}}' \ -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 10}' \ -report_every 5 \ >> ${LOG_FILE} 2>&1 [ "$?" -eq 0 ] || error_exit echo "Succeeded" | tee -a ${LOG_FILE} -echo -n " [+] Testing NMT transformer training w/ validation with dynamic scoring..." +echo -n " [+] Testing architecture custom transformer training w/ validation with dynamic scoring..." ${PYTHON} eole/bin/main.py train \ -config ${DATA_DIR}/data.yaml \ -src_vocab $TMP_OUT_DIR/eole.vocab.src \ @@ -192,7 +190,7 @@ ${PYTHON} eole/tests/test_events.py --logdir $TMP_OUT_DIR/logs_dynamic-scoring - echo "Succeeded" | tee -a ${LOG_FILE} rm -r $TMP_OUT_DIR/logs_dynamic-scoring -echo -n " [+] Testing NMT transformer training w/ validation with dynamic scoring and maxrelative ..." +echo -n " [+] Testing architecture transformer training w/ validation with dynamic scoring and maxrelative ..." ${PYTHON} eole/bin/main.py train \ -config ${DATA_DIR}/data.yaml \ -src_vocab $TMP_OUT_DIR/eole.vocab.src \ @@ -213,14 +211,14 @@ ${PYTHON} eole/tests/test_events.py --logdir $TMP_OUT_DIR/logs_dynamic-scoring_a echo "Succeeded" | tee -a ${LOG_FILE} rm -r $TMP_OUT_DIR/logs_dynamic-scoring_and_relative -echo -n " [+] Testing NMT transformer training w/ validation with dynamic scoring and rotary ..." +echo -n " [+] Testing architecture transformer training w/ validation with dynamic scoring and rotary ..." 
${PYTHON} eole/bin/main.py train \ -config ${DATA_DIR}/data.yaml \ -src_vocab $TMP_OUT_DIR/eole.vocab.src \ -tgt_vocab $TMP_OUT_DIR/eole.vocab.tgt \ -src_vocab_size 1000 \ -tgt_vocab_size 1000 \ - -model '{"architecture": "transformer", "layers": 4, "hidden_size": 16, "transformer_ff": 64, "heads": 2, "embeddings": {"word_vec_size": 16, "position_encoding_type": "Rotary"}}' \ + -model '{"architecture": "transformer", "layers": 4, "hidden_size": 16, "transformer_ff": 64, "heads": 2, "encoder": {"encoder_type": "transformer"}, "embeddings": {"word_vec_size": 16, "position_encoding_type": "Rotary"}}' \ -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 10, "valid_steps": 5}' \ -valid_metrics "BLEU" "TER" \ -report_every 2 \ @@ -228,13 +226,13 @@ ${PYTHON} eole/bin/main.py train \ -scoring_debug \ -dump_preds $TMP_OUT_DIR/dump_pred \ -tensorboard_log_dir $TMP_OUT_DIR/logs_dynamic-scoring_and_rotary >> ${LOG_FILE} 2>&1 - + ${PYTHON} eole/tests/test_events.py --logdir $TMP_OUT_DIR/logs_dynamic-scoring_and_rotary -tensorboard_checks valid_metrics [ "$?" -eq 0 ] || error_exit echo "Succeeded" | tee -a ${LOG_FILE} rm -r $TMP_OUT_DIR/logs_dynamic-scoring_and_rotary -echo -n " [+] Testing NMT transformer training w/ validation with dynamic scoring and alibi ..." +echo -n " [+] Testing architecture transformer training w/ validation with dynamic scoring and alibi ..." ${PYTHON} eole/bin/main.py train \ -config ${DATA_DIR}/data.yaml \ -src_vocab $TMP_OUT_DIR/eole.vocab.src \ @@ -255,7 +253,7 @@ ${PYTHON} eole/tests/test_events.py --logdir $TMP_OUT_DIR/logs_dynamic-scoring_a echo "Succeeded" | tee -a ${LOG_FILE} rm -r $TMP_OUT_DIR/logs_dynamic-scoring_and_alibi -echo -n " [+] Testing LM training..." +echo -n " [+] Testing architecture custom decoder only training..." ${PYTHON} eole/bin/main.py train \ -config ${DATA_DIR}/lm_data.yaml \ -src_vocab $TMP_OUT_DIR/eole.vocab.src \ @@ -270,66 +268,19 @@ ${PYTHON} eole/bin/main.py train \ [ "$?" -eq 0 ] || error_exit echo "Succeeded" | tee -a ${LOG_FILE} -echo -n " [+] Testing Checkpoint Vocabulary Update..." -${PYTHON} eole/bin/main.py train \ - -config ${DATA_DIR}/data.yaml \ - -src_vocab $TMP_OUT_DIR/eole.vocab.src \ - -tgt_vocab $TMP_OUT_DIR/eole.vocab.tgt \ - -src_vocab_size 1000 -tgt_vocab_size 1000 \ - -model '{"architecture": "rnn", "hidden_size": 10, "embeddings": {"word_vec_size": 5, "position_encoding_type": None}}' \ - -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 10, "model_path": "'"$TMP_OUT_DIR"'/eole.model", "save_checkpoint_steps": 10}' \ - -report_every 5 \ - >> ${LOG_FILE} 2>&1 -sed -i '1s/^/new_tok\t100000000\n/' $TMP_OUT_DIR/eole.vocab.src >> ${LOG_FILE} 2>&1 -${PYTHON} eole/bin/main.py train \ - -config ${DATA_DIR}/data.yaml \ - -src_vocab $TMP_OUT_DIR/eole.vocab.src \ - -tgt_vocab $TMP_OUT_DIR/eole.vocab.tgt \ - -src_vocab_size 1000 -tgt_vocab_size 1000 \ - -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 20, "train_from": "'"$TMP_OUT_DIR"'/eole.model/step_10", "save_checkpoint_steps": 10, "update_vocab": True, "reset_optim": "states"}' \ - -report_every 5 \ - >> ${LOG_FILE} 2>&1 -[ "$?" -eq 0 ] || error_exit -echo "Succeeded" | tee -a ${LOG_FILE} - -echo -n " [+] Testing Checkpoint Vocabulary Update with LM..." 
-${PYTHON} eole/bin/main.py train \ - -config ${DATA_DIR}/lm_data.yaml \ - -src_vocab $TMP_OUT_DIR/eole.vocab.src \ - -tgt_vocab $TMP_OUT_DIR/eole.vocab.src \ - -model '{"layers": 2, "hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16}, "encoder": None, "decoder": {"decoder_type": "transformer_lm", "heads": 4}}' \ - -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 10, "model_path": "'"$TMP_OUT_DIR"'/lm.eole.model", "save_checkpoint_steps": 10}' \ - -report_every 5 \ - -src_vocab_size 1000 \ - -tgt_vocab_size 1000 \ - >> ${LOG_FILE} 2>&1 -sed -i '1s/^/new_tok2\t100000000\n/' $TMP_OUT_DIR/eole.vocab.src >> ${LOG_FILE} 2>&1 -${PYTHON} eole/bin/main.py train \ - -config ${DATA_DIR}/lm_data.yaml \ - -src_vocab $TMP_OUT_DIR/eole.vocab.src \ - -tgt_vocab $TMP_OUT_DIR/eole.vocab.src \ - -model '{"layers": 2, "hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16}, "encoder": None, "decoder": {"decoder_type": "transformer_lm", "heads": 4}}' \ - -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 20, "train_from": "'"$TMP_OUT_DIR"'/lm.eole.model/step_10", "save_checkpoint_steps": 10, "update_vocab": True, "reset_optim": "states"}' \ - -report_every 5 \ - -src_vocab_size 1000 \ - -tgt_vocab_size 1000 \ - >> ${LOG_FILE} 2>&1 -[ "$?" -eq 0 ] || error_exit -echo "Succeeded" | tee -a ${LOG_FILE} - # # Translation test # echo "[+] Doing translation test..." -echo -n " [+] Testing NMT translation..." +echo -n " [+] Testing RNN translation..." head ${DATA_DIR}/src-test.txt > $TMP_OUT_DIR/src-test.txt ${PYTHON} eole/bin/main.py predict -model_path ${TEST_DIR}/test_model -src $TMP_OUT_DIR/src-test.txt -verbose >> ${LOG_FILE} 2>&1 [ "$?" -eq 0 ] || error_exit echo "Succeeded" | tee -a ${LOG_FILE} rm $TMP_OUT_DIR/src-test.txt -echo -n " [+] Testing NMT ensemble translation..." +echo -n " [+] Testing RNN ensemble translation..." head ${DATA_DIR}/src-test.txt > $TMP_OUT_DIR/src-test.txt ${PYTHON} eole/bin/main.py predict -model_path ${TEST_DIR}/test_model ${TEST_DIR}/test_model \ -src $TMP_OUT_DIR/src-test.txt -verbose >> ${LOG_FILE} 2>&1 @@ -337,7 +288,7 @@ ${PYTHON} eole/bin/main.py predict -model_path ${TEST_DIR}/test_model ${TEST_DIR echo "Succeeded" | tee -a ${LOG_FILE} rm $TMP_OUT_DIR/src-test.txt -echo -n " [+] Testing NMT translation w/ Beam search..." +echo -n " [+] Testing RNN translation w/ Beam search..." ${PYTHON} eole/bin/main.py predict -model_path ${TEST_DIR}/test_model2 \ -src ${DATA_DIR}/morph/src.valid \ -verbose \ @@ -350,7 +301,7 @@ diff ${DATA_DIR}/morph/tgt.valid $TMP_OUT_DIR/trans_beam echo "Succeeded" | tee -a ${LOG_FILE} rm $TMP_OUT_DIR/trans_beam -echo -n " [+] Testing NMT translation w/ Random Sampling..." +echo -n " [+] Testing RNN translation w/ Random Sampling..." ${PYTHON} eole/bin/main.py predict -model_path ${TEST_DIR}/test_model2 \ -src ${DATA_DIR}/morph/src.valid \ -verbose -batch_size 10 \ @@ -366,7 +317,6 @@ echo "Succeeded" | tee -a ${LOG_FILE} rm $TMP_OUT_DIR/trans_sampling echo -n " [+] Testing LM generation..." -echo " [+] Testing LM generation..." | tee -a ${LOG_FILE} head ${DATA_DIR}/src-test.txt > $TMP_OUT_DIR/src-test.txt ${PYTHON} eole/bin/main.py predict -model_path ${TEST_DIR}/test_model_lm -src $TMP_OUT_DIR/src-test.txt -verbose >> ${LOG_FILE} 2>&1 [ "$?" -eq 0 ] || error_exit @@ -374,7 +324,6 @@ echo "Succeeded" | tee -a ${LOG_FILE} rm $TMP_OUT_DIR/src-test.txt echo -n " [+] Testing LM generation w/ Beam search..." 
-echo " [+] Testing LM generation w/ Beam search..." | tee -a ${LOG_FILE} ${PYTHON} eole/bin/main.py predict -model_path ${TEST_DIR}/test_model_lm \ -src ${DATA_DIR}/data_lm/src-gen.txt \ -verbose -batch_size 1 \ @@ -442,7 +391,6 @@ rm $TMP_OUT_DIR/gen_sampling # Inference engines test # echo -n " [+] Testing PY LM inference engine .." -echo " [+] Testing PY LM inference engine .."| tee -a ${LOG_FILE} head ${DATA_DIR}/src-test.txt > $TMP_OUT_DIR/src-test.txt ${PYTHON} eole/tests/test_inference_engines.py -model ${TEST_DIR}/test_model_lm \ -model_type decoder \ @@ -455,7 +403,7 @@ rm $TMP_OUT_DIR/src-test.txt rm $TMP_OUT_DIR/inference_engine_lm_py_outputs_file.json rm $TMP_OUT_DIR/inference_engine_lm_py_outputs_list.json -echo " [+] Testing CT2 LM inference engine .."| tee -a ${LOG_FILE} +echo -n " [+] Testing CT2 LM inference engine .." head ${DATA_DIR}/src-test.txt > $TMP_OUT_DIR/src-test.txt ${PYTHON} eole/tests/test_inference_engines.py -model ${TEST_DIR} \ -model_type decoder \ @@ -470,7 +418,6 @@ rm $TMP_OUT_DIR/inference_engine_lm_ct2_outputs_file.json rm $TMP_OUT_DIR/inference_engine_lm_ct2_outputs_list.json echo -n " [+] Testing PY SEQ2SEQ inference engine .." -echo " [+] Testing PY SEQ2SEQ inference engine .."| tee -a ${LOG_FILE} head ${DATA_DIR}/src-test.txt > $TMP_OUT_DIR/src-test.txt ${PYTHON} eole/tests/test_inference_engines.py -model ${TEST_DIR}/test_model \ -model_type encoder_decoder \ @@ -516,6 +463,53 @@ PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH} ${PYTHON} eole/bin/main.py model extrac [ "$?" -eq 0 ] || error_exit echo "Succeeded" | tee -a ${LOG_FILE} +echo -n " [+] Testing architecture rnn Checkpoint Vocabulary Update..." +${PYTHON} eole/bin/main.py train \ + -config ${DATA_DIR}/data.yaml \ + -src_vocab $TMP_OUT_DIR/eole.vocab.src \ + -tgt_vocab $TMP_OUT_DIR/eole.vocab.tgt \ + -src_vocab_size 1000 -tgt_vocab_size 1000 \ + -model '{"architecture": "rnn", "hidden_size": 10, "embeddings": {"word_vec_size": 5, "position_encoding_type": None}}' \ + -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 10, "model_path": "'"$TMP_OUT_DIR"'/eole.model", "save_checkpoint_steps": 10}' \ + -report_every 5 \ + >> ${LOG_FILE} 2>&1 +sed -i '1s/^/new_tok\t100000000\n/' $TMP_OUT_DIR/eole.vocab.src >> ${LOG_FILE} 2>&1 +${PYTHON} eole/bin/main.py train \ + -config ${DATA_DIR}/data.yaml \ + -src_vocab $TMP_OUT_DIR/eole.vocab.src \ + -tgt_vocab $TMP_OUT_DIR/eole.vocab.tgt \ + -src_vocab_size 1000 -tgt_vocab_size 1000 \ + -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 20, "train_from": "'"$TMP_OUT_DIR"'/eole.model/step_10", "save_checkpoint_steps": 10, "update_vocab": True, "reset_optim": "states"}' \ + -report_every 5 \ + >> ${LOG_FILE} 2>&1 +[ "$?" -eq 0 ] || error_exit +echo "Succeeded" | tee -a ${LOG_FILE} + +echo -n " [+] Testing Checkpoint Vocabulary Update with LM..." 
+${PYTHON} eole/bin/main.py train \ + -config ${DATA_DIR}/lm_data.yaml \ + -src_vocab $TMP_OUT_DIR/eole.vocab.src \ + -tgt_vocab $TMP_OUT_DIR/eole.vocab.src \ + -model '{"layers": 2, "hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16}, "encoder": None, "decoder": {"decoder_type": "transformer_lm", "heads": 4}}' \ + -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 10, "model_path": "'"$TMP_OUT_DIR"'/lm.eole.model", "save_checkpoint_steps": 10}' \ + -report_every 5 \ + -src_vocab_size 1000 \ + -tgt_vocab_size 1000 \ + >> ${LOG_FILE} 2>&1 +sed -i '1s/^/new_tok2\t100000000\n/' $TMP_OUT_DIR/eole.vocab.src >> ${LOG_FILE} 2>&1 +${PYTHON} eole/bin/main.py train \ + -config ${DATA_DIR}/lm_data.yaml \ + -src_vocab $TMP_OUT_DIR/eole.vocab.src \ + -tgt_vocab $TMP_OUT_DIR/eole.vocab.src \ + -model '{"layers": 2, "hidden_size": 16, "transformer_ff": 64, "embeddings": {"word_vec_size": 16}, "encoder": None, "decoder": {"decoder_type": "transformer_lm", "heads": 4}}' \ + -training '{"batch_size": 10, "num_workers": 0, "bucket_size": 1024, "train_steps": 20, "train_from": "'"$TMP_OUT_DIR"'/lm.eole.model/step_10", "save_checkpoint_steps": 10, "update_vocab": True, "reset_optim": "states"}' \ + -report_every 5 \ + -src_vocab_size 1000 \ + -tgt_vocab_size 1000 \ + >> ${LOG_FILE} 2>&1 +[ "$?" -eq 0 ] || error_exit +echo "Succeeded" | tee -a ${LOG_FILE} + # Finally, clean up clean_up diff --git a/eole/tests/test_events.py b/eole/tests/test_events.py index 219d82a5..a456c151 100644 --- a/eole/tests/test_events.py +++ b/eole/tests/test_events.py @@ -47,5 +47,5 @@ def check_scalars(self, scalars, logdir): args = parser.parse_args() test_event = TestEvents() scalars = test_event.scalars[args.tensorboard_checks] - print("looking for scalars: ", scalars) + print("looking for scalars: ", scalars, end=" ") test_event.check_scalars(scalars, args.logdir) diff --git a/eole/tests/test_model.yml b/eole/tests/test_model.yml index a399e74e..b8222244 100644 --- a/eole/tests/test_model.yml +++ b/eole/tests/test_model.yml @@ -135,7 +135,6 @@ tgt_vocab_size: 1000 # dropout=[0.3], # attention_dropout=[0.1], # dropout_steps=[0], -# truncated_decoder=0, # adam_beta1=0.9, # adam_beta2=0.999, # label_smoothing=0.0, diff --git a/eole/tests/test_model_lm.yml b/eole/tests/test_model_lm.yml index d3f45b7b..c992677a 100644 --- a/eole/tests/test_model_lm.yml +++ b/eole/tests/test_model_lm.yml @@ -188,7 +188,6 @@ share_vocab: true # dropout=[0.1], # attention_dropout=[0.1], # dropout_steps=[0], -# truncated_decoder=0, # adam_beta1=0.9, # adam_beta2=0.998, # label_smoothing=0.1, diff --git a/eole/tests/test_model_lm/config.json b/eole/tests/test_model_lm/config.json index 57fb6cd2..5b8fd7d8 100644 --- a/eole/tests/test_model_lm/config.json +++ b/eole/tests/test_model_lm/config.json @@ -10,7 +10,7 @@ "layers": 2, "heads": 2 }, - "architecture": "transformer_lm", + "architecture": "custom", "encoder": null, "embeddings": { "tgt_word_vec_size": 64, diff --git a/eole/tests/test_models.py b/eole/tests/test_models.py index 68d66516..376723b0 100644 --- a/eole/tests/test_models.py +++ b/eole/tests/test_models.py @@ -8,7 +8,6 @@ import eole.inputters from eole.models.model import build_src_emb, build_tgt_emb, build_encoder, build_decoder -from eole.utils.misc import sequence_mask from eole.config.run import TrainConfig from eole.config.data import Dataset from eole.config.models import CustomModelConfig @@ -17,6 +16,14 @@ # but we can't because model building relies on some params # not 
currently in model config (dropout, freeze_word_vecs_enc, etc.) + +class NoOpPosition: + """A no-op position encoding callable.""" + + def update(self, *args, **kwargs): + return None + + opt = TrainConfig( data={ "dummy": Dataset(path_src="eole/tests/data/src-train.txt") @@ -32,6 +39,7 @@ def __init__(self, *args, **kwargs): super(TestModel, self).__init__(*args, **kwargs) self.opt = opt self.opt.training.self_attn_backend = "pytorch" + self.rope = NoOpPosition() def get_vocabs(self): src_vocab = pyonmttok.build_vocab_from_tokens( @@ -120,8 +128,9 @@ def encoder_forward(self, opt, source_l=3, bsize=1): test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize) emb_src = embeddings(test_src) - mask = sequence_mask(test_length) - enc_out, hidden_t = enc(emb_src, mask) + pad_idx = embeddings.word_padding_idx + pad_mask = test_src.eq(pad_idx).unsqueeze(1) + enc_out, hidden_t = enc(emb_src, pad_mask=pad_mask) # Initialize vectors to compare size with test_hid = torch.zeros( @@ -166,6 +175,7 @@ def model_forward(self, opt, source_l=3, bsize=1): src_emb=src_emb, tgt_emb=tgt_emb, hidden_size=opt.model.decoder.hidden_size, + rope=NoOpPosition(), ) test_src, test_tgt, test_length = self.get_batch(source_l=source_l, bsize=bsize) output, attn, estim = model(test_src, test_tgt, test_length) diff --git a/eole/train_single.py b/eole/train_single.py index 03c7e45c..12a43979 100644 --- a/eole/train_single.py +++ b/eole/train_single.py @@ -2,7 +2,6 @@ """Training on a single process.""" import torch import sys - from eole.utils.logging import init_logger, logger from eole.config.run import TrainConfig from eole.constants import CorpusTask @@ -153,6 +152,12 @@ def main(config, device_id): init_logger(config.log_file) checkpoint, vocabs, transforms, config = _init_train(config) + # Allow only Memory Efficient path for sdpa + torch.backends.cuda.enable_mem_efficient_sdp(True) + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_math_sdp(False) + torch.backends.cuda.enable_cudnn_sdp(False) + # if transform + options set in 'valid' we need to copy in main # transform / options for scoring considered as inference validset_transforms = getattr(config.data.get("valid", None), "transforms", None) @@ -188,6 +193,7 @@ def main(config, device_id): ) if config.training.torch_compile: + torch._dynamo.config.cache_size_limit = 16 model = torch.compile(model, dynamic=True) model.count_parameters(log=logger.info) diff --git a/eole/trainer.py b/eole/trainer.py index 16f5cf16..e613f761 100644 --- a/eole/trainer.py +++ b/eole/trainer.py @@ -48,9 +48,7 @@ def build_trainer(config, device_id, model, vocabs, optim, model_saver=None): scoring_preparator.warm_up(validset_transforms) scorers_cls = get_scorers_cls(config.valid_metrics) valid_scorers = build_scorers(config, scorers_cls) - running_config = config.training - trunc_size = running_config.truncated_decoder # Badly named... norm_method = running_config.normalization accum_count = running_config.accum_count accum_steps = running_config.accum_steps @@ -86,7 +84,6 @@ def build_trainer(config, device_id, model, vocabs, optim, model_saver=None): scoring_preparator, valid_scorers, optim, - trunc_size, norm_method, accum_count, accum_steps, @@ -127,8 +124,6 @@ class Trainer(object): of the validation metrics optim(:obj:`eole.utils.optimizers.Optimizer`): the optimizer responsible for update - trunc_size(int): length of truncated back propagation - through time accum_count(list): accumulate gradients this many times. 
accum_steps(list): steps for accum gradients changes. n_gpu (int): number of gpu. @@ -160,7 +155,6 @@ def __init__( scoring_preparator, valid_scorers, optim, - trunc_size=0, norm_method="sents", accum_count=[1], accum_steps=[0], @@ -188,11 +182,9 @@ def __init__( self.estim_loss_lambda = estim_loss_lambda[0] self.estim_loss_lambda_steps = estim_loss_lambda_steps self.valid_loss = valid_loss - self.scoring_preparator = scoring_preparator self.valid_scorers = valid_scorers self.optim = optim - self.trunc_size = trunc_size self.norm_method = norm_method self.accum_count_l = accum_count self.accum_count = accum_count[0] @@ -469,6 +461,7 @@ def validate(self, valid_iter, moving_average=None): # Update statistics. stats.update(metric_stats) + valid_model.decoder._disable_cache() if moving_average: for param_data, param in zip(model_params_data, self.model.parameters()): @@ -487,80 +480,54 @@ def _gradient_accumulation( Perform a backward on the loss of each sub_batch and finally update the params at the end of the big batch.""" - if self.accum_count > 1: - self.optim.zero_grad(set_to_none=True) + self.optim.zero_grad(set_to_none=True) for k, batch in enumerate(true_batches): - target_size = batch["tgt"].size(1) - # Truncated BPTT: reminder not compatible with accum > 1 - if self.trunc_size: - trunc_size = self.trunc_size - else: - trunc_size = target_size src = batch["src"] src_len = batch["srclen"] if src_len is not None: report_stats.n_src_words += src_len.sum().item() total_stats.n_src_words += src_len.sum().item() + tgt = batch["tgt"] - tgt_outer = batch["tgt"] - - bptt = False - for j in range(0, target_size - 1, trunc_size): - # 1. Create truncated target. - - tgt = tgt_outer[:, j : j + trunc_size] - - # 2. F-prop all but generator. - if self.accum_count == 1: - self.optim.zero_grad(set_to_none=True) - try: - with get_autocast(enabled=self.optim.amp): - model_out, attns, estim = self.model( - src, tgt, src_len, bptt=bptt, with_align=self.with_align - ) - bptt = True - - # 3. Compute loss. - if self.zero_out_prompt_loss: - # The loss of the prompt will be set to zero. - batch = self.train_loss.ignore_prompt(batch) - loss, batch_stats, auxloss = self.train_loss( - batch, - model_out, - attns, - trunc_start=j, - trunc_size=trunc_size, - estim=estim, - ) - if loss is not None: - loss /= normalization - auxloss /= self.accum_count * src_len.size(0) - loss = loss + auxloss * self.estim_loss_lambda - self.optim.backward(loss) - - total_stats.update(batch_stats) - report_stats.update(batch_stats) - - except Exception as exc: - trace_content = traceback.format_exc() - if "CUDA out of memory" in trace_content: - logger.info( - "Step %d, cuda OOM - batch removed", - self.optim.training_step, - ) - clear_gpu_cache() - if self.n_gpu > 1 and self.parallel_mode == "tensor_parallel": - torch.distributed.destroy_process_group() - sys.exit() - else: - traceback.print_exc() - raise exc - - # If truncated, don't backprop fully. - if self.model.decoder is not None and self.model.decoder.state != {}: - self.model.decoder.detach_state() + try: + with get_autocast(enabled=self.optim.amp): + model_out, attns, estim = self.model( + src, tgt, src_len, with_align=self.with_align + ) + if self.zero_out_prompt_loss: + # The loss of the prompt will be set to zero. 
+ batch = self.train_loss.ignore_prompt(batch) + loss, batch_stats, auxloss = self.train_loss( + batch, + model_out, + attns, + estim=estim, + ) + if loss is not None: + loss /= normalization + auxloss /= self.accum_count * src_len.size(0) + loss = loss + auxloss * self.estim_loss_lambda + self.optim.backward(loss) + + total_stats.update(batch_stats) + report_stats.update(batch_stats) + + except Exception as exc: + trace_content = traceback.format_exc() + if "CUDA out of memory" in trace_content: + logger.info( + "Step %d, cuda OOM - batch removed", + self.optim.training_step, + ) + clear_gpu_cache() + if self.n_gpu > 1 and self.parallel_mode == "tensor_parallel": + torch.distributed.destroy_process_group() + sys.exit() + else: + traceback.print_exc() + raise exc # in case of multi step gradient accumulation, # update only after accum batches diff --git a/eole/utils/loss.py b/eole/utils/loss.py index 8bb2c9b1..761062d5 100644 --- a/eole/utils/loss.py +++ b/eole/utils/loss.py @@ -270,13 +270,8 @@ def ignore_prompt(self, batch): batch["tgt"] += self.padding_idx * (1 - mask.int()) return batch - def forward(self, batch, output, attns, trunc_start=0, trunc_size=None, estim=None): - """Compute the forward loss, supports truncated BPTT for long - sequences by taking a range in the decoder output sequence to - back propagate in. - Range is from `(trunc_start, trunc_start + trunc_size)`. - Truncation is an approximate efficiency trick to relieve the - memory required in the RNN buffers. + def forward(self, batch, output, attns, estim=None): + """Compute the forward loss Args: batch (batch) : batch of labeled examples @@ -284,22 +279,12 @@ def forward(self, batch, output, attns, trunc_start=0, trunc_size=None, estim=No output of decoder model ``(batch, tgt_len, hidden)`` attns (dict) : dictionary of attention weights ``(batch, tgt_len, src_len)`` - trunc_start (int) : starting position of truncation window - trunc_size (int) : length of truncation window Returns: A tuple with the loss and a :obj:`eole.utils.Statistics` instance. 
""" - - if trunc_size is None: - trunc_size = batch["tgt"].size(1) - trunc_start # take into account here the tgt_shift_index (0 / 1 = LM/NMT) - trunc_range = (trunc_start + self.tgt_shift_index, trunc_start + trunc_size) - - target = batch["tgt"][:, trunc_range[0] : trunc_range[1]] - output = output[:, trunc_start : trunc_range[1], :].contiguous() - - flat_tgt = target[:, :].contiguous().view(-1) + flat_tgt = batch["tgt"][:, self.tgt_shift_index :].contiguous().view(-1) if self.generator is not None: scores = self.generator(self._bottle(output)) @@ -321,7 +306,7 @@ def forward(self, batch, output, attns, trunc_start=0, trunc_size=None, estim=No ref_align = eole.utils.make_batch_align_matrix( align_idx, align_matrix_size, normalize=True ) - ref_align = ref_align[:, trunc_range[0] : trunc_range[1], :] + ref_align = ref_align[:, self.tgt_shift_index :, :] if ref_align.dtype != loss.dtype: ref_align = ref_align.to(loss.dtype) align_loss = self._compute_alignement_loss( @@ -348,7 +333,7 @@ def forward(self, batch, output, attns, trunc_start=0, trunc_size=None, estim=No estimloss = self.estimloss(estim, batch["sco"]).to(estim.dtype) else: estimloss = torch.tensor([0.0], device=loss.device) - n_sents = len(batch["srclen"]) if trunc_start == 0 else 0 + n_sents = len(batch["srclen"]) stats = self._stats( n_sents, loss.sum().item(), estimloss.item(), scores, flat_tgt diff --git a/eole/utils/misc.py b/eole/utils/misc.py index 48510008..8d400a03 100644 --- a/eole/utils/misc.py +++ b/eole/utils/misc.py @@ -108,7 +108,7 @@ def get_autocast(enabled=True, device_type="auto"): device_type = get_device_type() if device_type == "cuda": - return torch.cuda.amp.autocast() + return torch.amp.autocast("cuda") elif device_type == "mps": return torch.amp.autocast(device_type="mps") else: diff --git a/eole/utils/optimizers.py b/eole/utils/optimizers.py index 166be4eb..f0e0d81d 100644 --- a/eole/utils/optimizers.py +++ b/eole/utils/optimizers.py @@ -393,9 +393,9 @@ def from_config(cls, model, config, checkpoint=None): optimizer._fp16 = "legacy" else: optimizer._fp16 = "amp" - from torch.cuda.amp import GradScaler + from torch.amp import GradScaler - optimizer._scaler = GradScaler() + optimizer._scaler = GradScaler("cuda") if optim_state_dict: optimizer.load_state_dict(optim_state_dict) return optimizer @@ -786,7 +786,7 @@ def step( # assuming a list/generator of parameter means single group elif isinstance(grads, types.GeneratorType): grads_group = [grads] - elif type(grads[0]) != list: + elif not isinstance(grads[0], list): grads_group = [grads] else: grads_group = grads @@ -795,7 +795,7 @@ def step( output_params_group = [None] * len(self.param_groups) elif isinstance(output_params, types.GeneratorType): output_params_group = [output_params] - elif type(output_params[0]) != list: + elif not isinstance(output_params[0], list): output_params_group = [output_params] else: output_params_group = output_params diff --git a/recipes/cometkiwi/cometkiwi-xl-eole.yaml b/recipes/cometkiwi/cometkiwi-xl-eole.yaml index 4dad7af0..88d9d0a6 100755 --- a/recipes/cometkiwi/cometkiwi-xl-eole.yaml +++ b/recipes/cometkiwi/cometkiwi-xl-eole.yaml @@ -118,12 +118,8 @@ model: add_qkvbias: true add_ffnbias: true mlp_activation_fn: gelu - #parallel_residual: true - #shared_layer_norm: true add_estimator: true share_decoder_embeddings: true - rope_config: - rotary_interleave: false layer_norm: standard norm_eps: 1e-5 embeddings: diff --git a/recipes/cometkiwi/cometkiwi-xxl-eole.yaml b/recipes/cometkiwi/cometkiwi-xxl-eole.yaml index 
f59a19d6..160efba4 100755 --- a/recipes/cometkiwi/cometkiwi-xxl-eole.yaml +++ b/recipes/cometkiwi/cometkiwi-xxl-eole.yaml @@ -116,12 +116,8 @@ model: add_qkvbias: true add_ffnbias: true mlp_activation_fn: gelu - #parallel_residual: true - #shared_layer_norm: true add_estimator: true share_decoder_embeddings: true - rope_config: - rotary_interleave: false layer_norm: standard norm_eps: 1e-5 embeddings:
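
[Editorial sketch, not part of the patch] The trainer and loss hunks above remove the truncated-BPTT windows in favor of a single full-sequence forward/backward per sub-batch, and misc.py/optimizers.py migrate from `torch.cuda.amp` to the `torch.amp` API. A minimal sketch of the resulting training step, assuming a generic `model` and `loss_fn` (illustrative names, not eole's trainer):

```python
import torch


def train_step(model, loss_fn, optimizer, scaler, batch, normalization: float = 1.0):
    """One full-sequence forward/backward per batch, no truncated BPTT windows."""
    optimizer.zero_grad(set_to_none=True)
    with torch.amp.autocast("cuda"):               # new-style autocast, replaces torch.cuda.amp
        output = model(batch["src"], batch["tgt"], batch["srclen"])
        loss = loss_fn(output, batch["tgt"]) / normalization
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()


# GradScaler now takes the device as its first argument
scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())
```

In the patch itself, `zero_grad` is moved outside the sub-batch loop so gradients accumulate across `accum_count` sub-batches before the optimizer update, rather than being cleared per window as in the old truncated path.
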