From 628c5e53f2e15f1b84caad88ca2ae8c40668b297 Mon Sep 17 00:00:00 2001 From: malodetz Date: Fri, 15 Jul 2022 00:03:50 +0300 Subject: [PATCH 01/33] Added useful classes --- .../data/comment_path_context_data_module.py | 10 ++++++++++ code2seq/data/comment_path_context_dataset.py | 19 +++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 code2seq/data/comment_path_context_data_module.py create mode 100644 code2seq/data/comment_path_context_dataset.py diff --git a/code2seq/data/comment_path_context_data_module.py b/code2seq/data/comment_path_context_data_module.py new file mode 100644 index 0000000..7b6e270 --- /dev/null +++ b/code2seq/data/comment_path_context_data_module.py @@ -0,0 +1,10 @@ +from code2seq.data.comment_path_context_dataset import CommentPathContextDataset +from code2seq.data.path_context_data_module import PathContextDataModule + + +class CommentPathContextDataModule(PathContextDataModule): + + def _create_dataset(self, holdout_file: str, random_context: bool) -> CommentPathContextDataset: + if self._vocabulary is None: + raise RuntimeError(f"Setup vocabulary before creating data loaders") + return CommentPathContextDataset(holdout_file, self._config, self._vocabulary, random_context) \ No newline at end of file diff --git a/code2seq/data/comment_path_context_dataset.py b/code2seq/data/comment_path_context_dataset.py new file mode 100644 index 0000000..1b87885 --- /dev/null +++ b/code2seq/data/comment_path_context_dataset.py @@ -0,0 +1,19 @@ +from typing import Dict, List, Optional + +from code2seq.data.path_context_dataset import PathContextDataset +from code2seq.data.vocabulary import Vocabulary + + +class CommentPathContextDataset(PathContextDataset): + + @staticmethod + def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]: + sublabels = raw_label.split(PathContextDataset._separator) + max_parts = max_parts or len(sublabels) + label_unk = vocab[Vocabulary.UNK] + + label = [vocab[Vocabulary.SOS]] + [vocab.get(st, label_unk) for st in sublabels[:max_parts]] + if len(sublabels) < max_parts: + label.append(vocab[Vocabulary.EOS]) + label += [vocab[Vocabulary.PAD]] * (max_parts + 1 - len(label)) + return label From 1ff19948cd7caf59efb2f69956023f701ec4969a Mon Sep 17 00:00:00 2001 From: malodetz Date: Fri, 15 Jul 2022 01:17:33 +0300 Subject: [PATCH 02/33] Config added --- config/comment-code2seq-java-large.yaml | 59 +++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 config/comment-code2seq-java-large.yaml diff --git a/config/comment-code2seq-java-large.yaml b/config/comment-code2seq-java-large.yaml new file mode 100644 index 0000000..2bfa23f --- /dev/null +++ b/config/comment-code2seq-java-large.yaml @@ -0,0 +1,59 @@ +data_folder: ./dataset + +checkpoint: null + +seed: 7 +# Training in notebooks (e.g. 
Google Colab) may crash with too small value +progress_bar_refresh_rate: 1 +print_config: true + +wandb: + project: Code2Seq -- java-med + group: null + offline: false + +data: + num_workers: 4 + + # Each token appears at least 10 times (99.2% coverage) + labels_count: 10 + max_label_parts: 256 + # Each token appears at least 1000 times (99.5% coverage) + tokens_count: 1000 + max_token_parts: 5 + path_length: 9 + + max_context: 200 + random_context: true + + batch_size: 512 + test_batch_size: 512 + +model: + # Encoder + embedding_size: 128 + encoder_dropout: 0.25 + encoder_rnn_size: 128 + use_bi_rnn: true + rnn_num_layers: 1 + + # Decoder + decoder_size: 320 + decoder_num_layers: 1 + rnn_dropout: 0.5 + +optimizer: + optimizer: "Momentum" + nesterov: true + lr: 0.01 + weight_decay: 0 + decay_gamma: 0.95 + +train: + n_epochs: 10 + patience: 10 + clip_norm: 5 + teacher_forcing: 1.0 + val_every_epoch: 1 + save_every_epoch: 1 + log_every_n_steps: 10 \ No newline at end of file From f8ffa007dd1d9bfcabdd079a85b798b4e4e5f496 Mon Sep 17 00:00:00 2001 From: malodetz Date: Fri, 15 Jul 2022 11:02:13 +0300 Subject: [PATCH 03/33] Added comment label processing --- code2seq/data/comment_path_context_dataset.py | 20 +++++++++++-------- requirements.txt | 8 ++++++-- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/code2seq/data/comment_path_context_dataset.py b/code2seq/data/comment_path_context_dataset.py index 1b87885..d43b5c7 100644 --- a/code2seq/data/comment_path_context_dataset.py +++ b/code2seq/data/comment_path_context_dataset.py @@ -1,19 +1,23 @@ from typing import Dict, List, Optional +from omegaconf import DictConfig +from transformers import RobertaTokenizerFast + from code2seq.data.path_context_dataset import PathContextDataset from code2seq.data.vocabulary import Vocabulary class CommentPathContextDataset(PathContextDataset): + def __init__(self, data_file: str, config: DictConfig, vocabulary: Vocabulary, random_context: bool): + super().__init__(data_file, config, vocabulary, random_context) @staticmethod def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]: - sublabels = raw_label.split(PathContextDataset._separator) - max_parts = max_parts or len(sublabels) - label_unk = vocab[Vocabulary.UNK] + tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base") + + label_with_spaces = ' '.join(raw_label.split(PathContextDataset._separator)) + label_tokens = [tokenizer.bos_token] + tokenizer.tokenize(label_with_spaces)[:max_parts - 2] + [ + tokenizer.eos_token] + label_tokens += [tokenizer.pad_token] * (max_parts - len(label_tokens)) - label = [vocab[Vocabulary.SOS]] + [vocab.get(st, label_unk) for st in sublabels[:max_parts]] - if len(sublabels) < max_parts: - label.append(vocab[Vocabulary.EOS]) - label += [vocab[Vocabulary.PAD]] * (max_parts + 1 - len(label)) - return label + return tokenizer.convert_tokens_to_ids(label_tokens) diff --git a/requirements.txt b/requirements.txt index bad2311..1529943 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,11 @@ -torch==1.10.0 +torch~=1.9.0 pytorch-lightning==1.5.1 -torchmetrics==0.6.0 +torchmetrics~=0.5.0 tqdm==4.62.3 wandb==0.12.6 omegaconf==2.1.1 commode-utils==0.4.1 + +typing~=3.10.0.0 +transformers~=4.18.0 +setuptools~=50.3.1 \ No newline at end of file From 5f39e68b25324bd5eced1a9938d4466150e1868c Mon Sep 17 00:00:00 2001 From: malodetz Date: Fri, 15 Jul 2022 11:23:21 +0300 Subject: [PATCH 04/33] Added wrapper --- code2seq/comment_code2seq_wrapper.py | 61 
+++++++++++++++++++ code2seq/data/comment_path_context_dataset.py | 4 -- 2 files changed, 61 insertions(+), 4 deletions(-) create mode 100644 code2seq/comment_code2seq_wrapper.py diff --git a/code2seq/comment_code2seq_wrapper.py b/code2seq/comment_code2seq_wrapper.py new file mode 100644 index 0000000..4080570 --- /dev/null +++ b/code2seq/comment_code2seq_wrapper.py @@ -0,0 +1,61 @@ +from argparse import ArgumentParser +from typing import cast + +import torch +from commode_utils.common import print_config +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning import Trainer + +from code2seq.data.comment_path_context_data_module import CommentPathContextDataModule +from code2seq.model import Code2Seq +from code2seq.utils.common import filter_warnings +from code2seq.utils.test import test +from code2seq.utils.train import train + + +def configure_arg_parser() -> ArgumentParser: + arg_parser = ArgumentParser() + arg_parser.add_argument("mode", help="Mode to run script", choices=["train", "test"]) + arg_parser.add_argument("-c", "--config", help="Path to YAML configuration file", type=str) + return arg_parser + + +def train_code2seq(config: DictConfig): + filter_warnings() + + if config.print_config: + print_config(config, fields=["model", "data", "train", "optimizer"]) + + # Load data module + data_module = CommentPathContextDataModule(config.data_folder, config.data) + + # Load model + code2seq = Code2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing) + + train(code2seq, data_module, config) + + +def test_code2seq(config: DictConfig): + filter_warnings() + + # Load data module + data_module = CommentPathContextDataModule(config.data_folder, config.data) + + # Load model + code2seq = Code2Seq.load_from_checkpoint(config.checkpoint, map_location=torch.device("cpu")) + + test(code2seq, data_module, config.seed) + + trainer = Trainer() + print(trainer.predict(model=code2seq, datamodule=data_module, return_predictions=True)) + + +if __name__ == "__main__": + __arg_parser = configure_arg_parser() + __args = __arg_parser.parse_args() + + __config = cast(DictConfig, OmegaConf.load(__args.config)) + if __args.mode == "train": + train_code2seq(__config) + else: + test_code2seq(__config) diff --git a/code2seq/data/comment_path_context_dataset.py b/code2seq/data/comment_path_context_dataset.py index d43b5c7..5f65010 100644 --- a/code2seq/data/comment_path_context_dataset.py +++ b/code2seq/data/comment_path_context_dataset.py @@ -1,15 +1,11 @@ from typing import Dict, List, Optional -from omegaconf import DictConfig from transformers import RobertaTokenizerFast from code2seq.data.path_context_dataset import PathContextDataset -from code2seq.data.vocabulary import Vocabulary class CommentPathContextDataset(PathContextDataset): - def __init__(self, data_file: str, config: DictConfig, vocabulary: Vocabulary, random_context: bool): - super().__init__(data_file, config, vocabulary, random_context) @staticmethod def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]: From 86204145fecdcc970ceea00c9f6be7d204362b19 Mon Sep 17 00:00:00 2001 From: malodetz Date: Fri, 15 Jul 2022 11:47:42 +0300 Subject: [PATCH 05/33] Update requirements.txt --- requirements.txt | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/requirements.txt b/requirements.txt index 1529943..3798bbd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,10 @@ -torch~=1.9.0 -pytorch-lightning==1.5.1 
-torchmetrics~=0.5.0 -tqdm==4.62.3 -wandb==0.12.6 -omegaconf==2.1.1 -commode-utils==0.4.1 - -typing~=3.10.0.0 -transformers~=4.18.0 -setuptools~=50.3.1 \ No newline at end of file +torch==1.12.0 +pytorch-lightning==1.6.5 +torchmetrics==0.9.2 +tqdm==4.64.0 +wandb==0.12.21 +omegaconf==2.2.2 +commode-utils==0.4.2 +typing==3.7.4.3 +transformers==4.20.1 +setuptools==63.2.0 \ No newline at end of file From 7624d44e4ea9136c186a1a6e94d53a33d7af4910 Mon Sep 17 00:00:00 2001 From: malodetz Date: Fri, 15 Jul 2022 12:45:30 +0300 Subject: [PATCH 06/33] Fixing train to use gpu --- code2seq/utils/train.py | 1 + config/comment-code2seq-java-large.yaml | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/code2seq/utils/train.py b/code2seq/utils/train.py index 4cca70d..d9940f2 100644 --- a/code2seq/utils/train.py +++ b/code2seq/utils/train.py @@ -44,6 +44,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict log_every_n_steps=params.log_every_n_steps, logger=wandb_logger, gpus=params.gpu, + auto_select_gpus=True, callbacks=[lr_logger, early_stopping_callback, checkpoint_callback, print_epoch_result_callback, progress_bar], resume_from_checkpoint=config.get("checkpoint", None), ) diff --git a/config/comment-code2seq-java-large.yaml b/config/comment-code2seq-java-large.yaml index 2bfa23f..7e9483f 100644 --- a/config/comment-code2seq-java-large.yaml +++ b/config/comment-code2seq-java-large.yaml @@ -26,8 +26,8 @@ data: max_context: 200 random_context: true - batch_size: 512 - test_batch_size: 512 + batch_size: 16 + test_batch_size: 8 model: # Encoder @@ -50,6 +50,7 @@ optimizer: decay_gamma: 0.95 train: + gpu: 1 n_epochs: 10 patience: 10 clip_norm: 5 From fb5fb266966719a8791920e1b86bdcdfd71dfd75 Mon Sep 17 00:00:00 2001 From: malodetz Date: Fri, 15 Jul 2022 12:45:57 +0300 Subject: [PATCH 07/33] New model with correct decoder --- code2seq/comment_code2seq_wrapper.py | 6 +-- code2seq/model/comment_code2seq.py | 73 ++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 3 deletions(-) create mode 100644 code2seq/model/comment_code2seq.py diff --git a/code2seq/comment_code2seq_wrapper.py b/code2seq/comment_code2seq_wrapper.py index 4080570..e4ec726 100644 --- a/code2seq/comment_code2seq_wrapper.py +++ b/code2seq/comment_code2seq_wrapper.py @@ -7,7 +7,7 @@ from pytorch_lightning import Trainer from code2seq.data.comment_path_context_data_module import CommentPathContextDataModule -from code2seq.model import Code2Seq +from code2seq.model.comment_code2seq import CommentCode2Seq from code2seq.utils.common import filter_warnings from code2seq.utils.test import test from code2seq.utils.train import train @@ -30,7 +30,7 @@ def train_code2seq(config: DictConfig): data_module = CommentPathContextDataModule(config.data_folder, config.data) # Load model - code2seq = Code2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing) + code2seq = CommentCode2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing) train(code2seq, data_module, config) @@ -42,7 +42,7 @@ def test_code2seq(config: DictConfig): data_module = CommentPathContextDataModule(config.data_folder, config.data) # Load model - code2seq = Code2Seq.load_from_checkpoint(config.checkpoint, map_location=torch.device("cpu")) + code2seq = CommentCode2Seq.load_from_checkpoint(config.checkpoint, map_location=torch.device("cpu")) test(code2seq, data_module, config.seed) diff --git a/code2seq/model/comment_code2seq.py 
b/code2seq/model/comment_code2seq.py new file mode 100644 index 0000000..a23fceb --- /dev/null +++ b/code2seq/model/comment_code2seq.py @@ -0,0 +1,73 @@ +from typing import Dict + +import torch +from commode_utils.losses import SequenceCrossEntropyLoss +from commode_utils.metrics import SequentialF1Score, ClassificationMetrics +from commode_utils.modules import LSTMDecoderStep, Decoder +from omegaconf import DictConfig +from pytorch_lightning import LightningModule +from pytorch_lightning.utilities.types import EPOCH_OUTPUT +from torchmetrics import MetricCollection, Metric +from transformers import RobertaTokenizerFast + +from code2seq.data.path_context import BatchedLabeledPathContext +from code2seq.data.vocabulary import Vocabulary +from code2seq.model import Code2Seq + + +class CommentCode2Seq(Code2Seq): + def __init__( + self, + model_config: DictConfig, + optimizer_config: DictConfig, + vocabulary: Vocabulary, + teacher_forcing: float = 0.0, + ): + super().__init__(model_config, optimizer_config, vocabulary, teacher_forcing) + + tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base") + self.__pad_idx = tokenizer.pad_token_id + eos_idx = tokenizer.eos_token_id + ignore_idx = [tokenizer.bos_token_id, tokenizer.unk_token_id] + metrics: Dict[str, Metric] = { + f"{holdout}_f1": SequentialF1Score(pad_idx=self.__pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx) + for holdout in ["train", "val", "test"] + } + #TODO add chrf back + self.__metrics = MetricCollection(metrics) + + self._encoder = self._get_encoder(model_config) + output_size = len(tokenizer.get_vocab()) + decoder_step = LSTMDecoderStep(model_config, output_size, self.__pad_idx) + self._decoder = Decoder(decoder_step, output_size, tokenizer.eos_token_id, teacher_forcing) + + self.__loss = SequenceCrossEntropyLoss(self.__pad_idx, reduction="batch-mean") + + def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict: + target_sequence = batch.labels if step == "train" else None + # [seq length; batch size; vocab size] + logits, _ = self.logits_from_batch(batch, target_sequence) + result = {f"{step}/loss": self.__loss(logits[1:], batch.labels[1:])} + + with torch.no_grad(): + prediction = logits.argmax(-1) + metric: ClassificationMetrics = self.__metrics[f"{step}_f1"](prediction, batch.labels) + result.update( + {f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall} + ) + + return result + + def _shared_epoch_end(self, step_outputs: EPOCH_OUTPUT, step: str): + with torch.no_grad(): + losses = [so if isinstance(so, torch.Tensor) else so["loss"] for so in step_outputs] + mean_loss = torch.stack(losses).mean() + metric = self.__metrics[f"{step}_f1"].compute() + log = { + f"{step}/loss": mean_loss, + f"{step}/f1": metric.f1_score, + f"{step}/precision": metric.precision, + f"{step}/recall": metric.recall, + } + self.__metrics[f"{step}_f1"].reset() + self.log_dict(log, on_step=False, on_epoch=True) From c4a203b4ebb7a821a40b5ccd26f95f3d8f28196a Mon Sep 17 00:00:00 2001 From: malodetz Date: Fri, 15 Jul 2022 12:58:07 +0300 Subject: [PATCH 08/33] Fix black --- code2seq/data/comment_path_context_data_module.py | 3 +-- code2seq/data/comment_path_context_dataset.py | 8 ++++---- code2seq/model/comment_code2seq.py | 13 ++++++------- code2seq/utils/optimization.py | 2 +- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/code2seq/data/comment_path_context_data_module.py b/code2seq/data/comment_path_context_data_module.py index 7b6e270..2e40dc1 100644 
--- a/code2seq/data/comment_path_context_data_module.py +++ b/code2seq/data/comment_path_context_data_module.py @@ -3,8 +3,7 @@ class CommentPathContextDataModule(PathContextDataModule): - def _create_dataset(self, holdout_file: str, random_context: bool) -> CommentPathContextDataset: if self._vocabulary is None: raise RuntimeError(f"Setup vocabulary before creating data loaders") - return CommentPathContextDataset(holdout_file, self._config, self._vocabulary, random_context) \ No newline at end of file + return CommentPathContextDataset(holdout_file, self._config, self._vocabulary, random_context) diff --git a/code2seq/data/comment_path_context_dataset.py b/code2seq/data/comment_path_context_dataset.py index 5f65010..a82ba07 100644 --- a/code2seq/data/comment_path_context_dataset.py +++ b/code2seq/data/comment_path_context_dataset.py @@ -6,14 +6,14 @@ class CommentPathContextDataset(PathContextDataset): - @staticmethod def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]: tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base") - label_with_spaces = ' '.join(raw_label.split(PathContextDataset._separator)) - label_tokens = [tokenizer.bos_token] + tokenizer.tokenize(label_with_spaces)[:max_parts - 2] + [ - tokenizer.eos_token] + label_with_spaces = " ".join(raw_label.split(PathContextDataset._separator)) + label_tokens = ( + [tokenizer.bos_token] + tokenizer.tokenize(label_with_spaces)[: max_parts - 2] + [tokenizer.eos_token] + ) label_tokens += [tokenizer.pad_token] * (max_parts - len(label_tokens)) return tokenizer.convert_tokens_to_ids(label_tokens) diff --git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py index a23fceb..5746d31 100644 --- a/code2seq/model/comment_code2seq.py +++ b/code2seq/model/comment_code2seq.py @@ -5,7 +5,6 @@ from commode_utils.metrics import SequentialF1Score, ClassificationMetrics from commode_utils.modules import LSTMDecoderStep, Decoder from omegaconf import DictConfig -from pytorch_lightning import LightningModule from pytorch_lightning.utilities.types import EPOCH_OUTPUT from torchmetrics import MetricCollection, Metric from transformers import RobertaTokenizerFast @@ -17,11 +16,11 @@ class CommentCode2Seq(Code2Seq): def __init__( - self, - model_config: DictConfig, - optimizer_config: DictConfig, - vocabulary: Vocabulary, - teacher_forcing: float = 0.0, + self, + model_config: DictConfig, + optimizer_config: DictConfig, + vocabulary: Vocabulary, + teacher_forcing: float = 0.0, ): super().__init__(model_config, optimizer_config, vocabulary, teacher_forcing) @@ -33,7 +32,7 @@ def __init__( f"{holdout}_f1": SequentialF1Score(pad_idx=self.__pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx) for holdout in ["train", "val", "test"] } - #TODO add chrf back + # TODO add chrf back self.__metrics = MetricCollection(metrics) self._encoder = self._get_encoder(model_config) diff --git a/code2seq/utils/optimization.py b/code2seq/utils/optimization.py index 7d3e501..7b3183a 100644 --- a/code2seq/utils/optimization.py +++ b/code2seq/utils/optimization.py @@ -29,5 +29,5 @@ def configure_optimizers_alon( optimizer = Adam(parameters, optim_config.lr, weight_decay=optim_config.weight_decay) else: raise ValueError(f"Unknown optimizer name: {optim_config.optimizer}, try one of: Adam, Momentum") - scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: optim_config.decay_gamma ** epoch) + scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: optim_config.decay_gamma**epoch) return 
[optimizer], [scheduler] From 6fdee0331b7648afa6e0d1d449842447b3effa64 Mon Sep 17 00:00:00 2001 From: malodetz Date: Sat, 16 Jul 2022 00:28:50 +0300 Subject: [PATCH 09/33] Added custom chrf metric --- code2seq/comment_code2seq_wrapper.py | 12 ++-- code2seq/model/code2seq.py | 40 ++++++------ code2seq/model/comment_code2seq.py | 91 ++++++++++++++++------------ requirements.txt | 3 +- 4 files changed, 81 insertions(+), 65 deletions(-) diff --git a/code2seq/comment_code2seq_wrapper.py b/code2seq/comment_code2seq_wrapper.py index e4ec726..9a75378 100644 --- a/code2seq/comment_code2seq_wrapper.py +++ b/code2seq/comment_code2seq_wrapper.py @@ -17,6 +17,8 @@ def configure_arg_parser() -> ArgumentParser: arg_parser = ArgumentParser() arg_parser.add_argument("mode", help="Mode to run script", choices=["train", "test"]) arg_parser.add_argument("-c", "--config", help="Path to YAML configuration file", type=str) + arg_parser.add_argument("-p", "--pretrained", help="Path to pretrained model", type=str, required=False, + default=None) return arg_parser @@ -35,20 +37,17 @@ def train_code2seq(config: DictConfig): train(code2seq, data_module, config) -def test_code2seq(config: DictConfig): +def test_code2seq(model_path: str, config: DictConfig): filter_warnings() # Load data module data_module = CommentPathContextDataModule(config.data_folder, config.data) # Load model - code2seq = CommentCode2Seq.load_from_checkpoint(config.checkpoint, map_location=torch.device("cpu")) + code2seq = CommentCode2Seq.load_from_checkpoint(model_path, map_location=torch.device("cpu")) test(code2seq, data_module, config.seed) - trainer = Trainer() - print(trainer.predict(model=code2seq, datamodule=data_module, return_predictions=True)) - if __name__ == "__main__": __arg_parser = configure_arg_parser() @@ -58,4 +57,5 @@ def test_code2seq(config: DictConfig): if __args.mode == "train": train_code2seq(__config) else: - test_code2seq(__config) + assert __args.pretrained is not None + test_code2seq(__args.pretrained, __config) diff --git a/code2seq/model/code2seq.py b/code2seq/model/code2seq.py index 603760d..c63b9c2 100644 --- a/code2seq/model/code2seq.py +++ b/code2seq/model/code2seq.py @@ -21,11 +21,11 @@ class Code2Seq(LightningModule): def __init__( - self, - model_config: DictConfig, - optimizer_config: DictConfig, - vocabulary: Vocabulary, - teacher_forcing: float = 0.0, + self, + model_config: DictConfig, + optimizer_config: DictConfig, + vocabulary: Vocabulary, + teacher_forcing: float = 0.0, ): super().__init__() self.save_hyperparameters() @@ -46,7 +46,7 @@ def __init__( metrics.update( {f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self.__pad_idx, eos_idx]) for holdout in ["val", "test"]} ) - self.__metrics = MetricCollection(metrics) + self._metrics = MetricCollection(metrics) self._encoder = self._get_encoder(model_config) decoder_step = LSTMDecoderStep(model_config, len(vocabulary.label_to_id), self.__pad_idx) @@ -75,13 +75,13 @@ def configure_optimizers(self) -> Tuple[List[Optimizer], List[_LRScheduler]]: return configure_optimizers_alon(self._optim_config, self.parameters()) def forward( # type: ignore - self, - from_token: torch.Tensor, - path_nodes: torch.Tensor, - to_token: torch.Tensor, - contexts_per_label: torch.Tensor, - output_length: int, - target_sequence: torch.Tensor = None, + self, + from_token: torch.Tensor, + path_nodes: torch.Tensor, + to_token: torch.Tensor, + contexts_per_label: torch.Tensor, + output_length: int, + target_sequence: torch.Tensor = None, ) -> Tuple[torch.Tensor, torch.Tensor]: 
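        # Encode every path context into a single vector, then decode the target
        # sequence step by step, attending over the encoded contexts of each sample.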
encoded_paths = self._encoder(from_token, path_nodes, to_token) output_logits, attention_weights = self._decoder( @@ -92,7 +92,7 @@ def forward( # type: ignore # ========== Model step ========== def logits_from_batch( - self, batch: BatchedLabeledPathContext, target_sequence: Optional[torch.Tensor] = None + self, batch: BatchedLabeledPathContext, target_sequence: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: return self( batch.from_token, @@ -111,12 +111,12 @@ def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict: with torch.no_grad(): prediction = logits.argmax(-1) - metric: ClassificationMetrics = self.__metrics[f"{step}_f1"](prediction, batch.labels) + metric: ClassificationMetrics = self._metrics[f"{step}_f1"](prediction, batch.labels) result.update( {f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall} ) if step != "train": - result[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"](prediction, batch.labels) + result[f"{step}/chrf"] = self._metrics[f"{step}_chrf"](prediction, batch.labels) return result @@ -140,17 +140,17 @@ def _shared_epoch_end(self, step_outputs: EPOCH_OUTPUT, step: str): with torch.no_grad(): losses = [so if isinstance(so, torch.Tensor) else so["loss"] for so in step_outputs] mean_loss = torch.stack(losses).mean() - metric = self.__metrics[f"{step}_f1"].compute() + metric = self._metrics[f"{step}_f1"].compute() log = { f"{step}/loss": mean_loss, f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall, } - self.__metrics[f"{step}_f1"].reset() + self._metrics[f"{step}_f1"].reset() if step != "train": - log[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"].compute() - self.__metrics[f"{step}_chrf"].reset() + log[f"{step}/chrf"] = self._metrics[f"{step}_chrf"].compute() + self._metrics[f"{step}_chrf"].reset() self.log_dict(log, on_step=False, on_epoch=True) def training_epoch_end(self, step_outputs: EPOCH_OUTPUT): diff --git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py index 5746d31..5faa4c2 100644 --- a/code2seq/model/comment_code2seq.py +++ b/code2seq/model/comment_code2seq.py @@ -2,29 +2,29 @@ import torch from commode_utils.losses import SequenceCrossEntropyLoss -from commode_utils.metrics import SequentialF1Score, ClassificationMetrics +from commode_utils.metrics import SequentialF1Score from commode_utils.modules import LSTMDecoderStep, Decoder from omegaconf import DictConfig -from pytorch_lightning.utilities.types import EPOCH_OUTPUT +from sacrebleu import CHRF from torchmetrics import MetricCollection, Metric from transformers import RobertaTokenizerFast -from code2seq.data.path_context import BatchedLabeledPathContext from code2seq.data.vocabulary import Vocabulary from code2seq.model import Code2Seq class CommentCode2Seq(Code2Seq): def __init__( - self, - model_config: DictConfig, - optimizer_config: DictConfig, - vocabulary: Vocabulary, - teacher_forcing: float = 0.0, + self, + model_config: DictConfig, + optimizer_config: DictConfig, + vocabulary: Vocabulary, + teacher_forcing: float = 0.0, ): super().__init__(model_config, optimizer_config, vocabulary, teacher_forcing) tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base") + self.__pad_idx = tokenizer.pad_token_id eos_idx = tokenizer.eos_token_id ignore_idx = [tokenizer.bos_token_id, tokenizer.unk_token_id] @@ -32,8 +32,12 @@ def __init__( f"{holdout}_f1": SequentialF1Score(pad_idx=self.__pad_idx, eos_idx=eos_idx, 
ignore_idx=ignore_idx) for holdout in ["train", "val", "test"] } - # TODO add chrf back - self.__metrics = MetricCollection(metrics) + + # TODO add concatenation and rouge-L metric + metrics.update( + {f"{holdout}_chrf": CommentChrF(tokenizer) for holdout in ["val", "test"]} + ) + self._metrics = MetricCollection(metrics) self._encoder = self._get_encoder(model_config) output_size = len(tokenizer.get_vocab()) @@ -42,31 +46,42 @@ def __init__( self.__loss = SequenceCrossEntropyLoss(self.__pad_idx, reduction="batch-mean") - def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict: - target_sequence = batch.labels if step == "train" else None - # [seq length; batch size; vocab size] - logits, _ = self.logits_from_batch(batch, target_sequence) - result = {f"{step}/loss": self.__loss(logits[1:], batch.labels[1:])} - - with torch.no_grad(): - prediction = logits.argmax(-1) - metric: ClassificationMetrics = self.__metrics[f"{step}_f1"](prediction, batch.labels) - result.update( - {f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall} - ) - - return result - - def _shared_epoch_end(self, step_outputs: EPOCH_OUTPUT, step: str): - with torch.no_grad(): - losses = [so if isinstance(so, torch.Tensor) else so["loss"] for so in step_outputs] - mean_loss = torch.stack(losses).mean() - metric = self.__metrics[f"{step}_f1"].compute() - log = { - f"{step}/loss": mean_loss, - f"{step}/f1": metric.f1_score, - f"{step}/precision": metric.precision, - f"{step}/recall": metric.recall, - } - self.__metrics[f"{step}_f1"].reset() - self.log_dict(log, on_step=False, on_epoch=True) + +class CommentChrF(Metric): + def __init__(self, tokenizer: RobertaTokenizerFast, **kwargs): + super().__init__(**kwargs) + self.__tokenizer = tokenizer + self.__chrf = CHRF() + + # Metric states + self.add_state("chrf", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum") + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, predicted: torch.Tensor, target: torch.Tensor): + """Calculated ChrF metric on predicted tensor w.r.t. target tensor. 
+ + :param predicted: [pred seq len; batch size] -- tensor with predicted tokens + :param target: [target seq len; batch size] -- tensor with ground truth tokens + :return: + """ + batch_size = target.shape[1] + if predicted.shape[1] != batch_size: + raise ValueError(f"Wrong batch size for prediction (expected: {batch_size}, actual: {predicted.shape[1]})") + + for batch_idx in range(batch_size): + target_seq = [token.item() for token in target[:, batch_idx]] + predicted_seq = [token.item() for token in predicted[:, batch_idx]] + + target_str = " ".join(self.__tokenizer.convert_ids_to_tokens(target_seq, skip_special_tokens=True)) + predicted_str = " ".join(self.__tokenizer.convert_ids_to_tokens(predicted_seq, skip_special_tokens=True)) + print(target_str, "||", predicted_str) + + if target_str == "": + # Empty target string mean that the original string encoded only with token + continue + + self.chrf += self.__chrf.sentence_score(predicted_str, [target_str]).score + self.count += 1 + + def compute(self) -> torch.Tensor: + return self.chrf / self.count diff --git a/requirements.txt b/requirements.txt index 3798bbd..83caa6f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ omegaconf==2.2.2 commode-utils==0.4.2 typing==3.7.4.3 transformers==4.20.1 -setuptools==63.2.0 \ No newline at end of file +setuptools==63.2.0 +sacrebleu>=2.0.0 \ No newline at end of file From db702700979abbd0e51d7a4698e1cbd0c155404d Mon Sep 17 00:00:00 2001 From: malodetz Date: Sat, 16 Jul 2022 18:11:41 +0300 Subject: [PATCH 10/33] Minor updates --- code2seq/data/comment_path_context_dataset.py | 7 ++++--- code2seq/model/comment_code2seq.py | 1 - config/comment-code2seq-java-large.yaml | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/code2seq/data/comment_path_context_dataset.py b/code2seq/data/comment_path_context_dataset.py index a82ba07..ba16dac 100644 --- a/code2seq/data/comment_path_context_dataset.py +++ b/code2seq/data/comment_path_context_dataset.py @@ -11,9 +11,10 @@ def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[in tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base") label_with_spaces = " ".join(raw_label.split(PathContextDataset._separator)) - label_tokens = ( - [tokenizer.bos_token] + tokenizer.tokenize(label_with_spaces)[: max_parts - 2] + [tokenizer.eos_token] - ) + label_tokens = tokenizer.tokenize(label_with_spaces) + if max_parts is None: + max_parts = len(label_tokens) + label_tokens = [tokenizer.bos_token] + label_tokens[: max_parts - 2] + [tokenizer.eos_token] label_tokens += [tokenizer.pad_token] * (max_parts - len(label_tokens)) return tokenizer.convert_tokens_to_ids(label_tokens) diff --git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py index 5faa4c2..d713b92 100644 --- a/code2seq/model/comment_code2seq.py +++ b/code2seq/model/comment_code2seq.py @@ -74,7 +74,6 @@ def update(self, predicted: torch.Tensor, target: torch.Tensor): target_str = " ".join(self.__tokenizer.convert_ids_to_tokens(target_seq, skip_special_tokens=True)) predicted_str = " ".join(self.__tokenizer.convert_ids_to_tokens(predicted_seq, skip_special_tokens=True)) - print(target_str, "||", predicted_str) if target_str == "": # Empty target string mean that the original string encoded only with token diff --git a/config/comment-code2seq-java-large.yaml b/config/comment-code2seq-java-large.yaml index 7e9483f..4169c2c 100644 --- a/config/comment-code2seq-java-large.yaml +++ 
b/config/comment-code2seq-java-large.yaml @@ -19,14 +19,14 @@ data: labels_count: 10 max_label_parts: 256 # Each token appears at least 1000 times (99.5% coverage) - tokens_count: 1000 + tokens_count: 1 max_token_parts: 5 path_length: 9 max_context: 200 random_context: true - batch_size: 16 + batch_size: 8 test_batch_size: 8 model: @@ -51,7 +51,7 @@ optimizer: train: gpu: 1 - n_epochs: 10 + n_epochs: 100 patience: 10 clip_norm: 5 teacher_forcing: 1.0 From b68cc108e219d150d67342d9c9bdc6f30e4278d7 Mon Sep 17 00:00:00 2001 From: malodetz Date: Sun, 17 Jul 2022 00:38:47 +0300 Subject: [PATCH 11/33] Fix random --- code2seq/comment_code2seq_wrapper.py | 8 +++++--- code2seq/model/code2seq.py | 26 +++++++++++++------------- code2seq/model/comment_code2seq.py | 14 ++++++-------- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/code2seq/comment_code2seq_wrapper.py b/code2seq/comment_code2seq_wrapper.py index 9a75378..b9b0a09 100644 --- a/code2seq/comment_code2seq_wrapper.py +++ b/code2seq/comment_code2seq_wrapper.py @@ -4,7 +4,7 @@ import torch from commode_utils.common import print_config from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer +from pytorch_lightning import seed_everything from code2seq.data.comment_path_context_data_module import CommentPathContextDataModule from code2seq.model.comment_code2seq import CommentCode2Seq @@ -17,8 +17,9 @@ def configure_arg_parser() -> ArgumentParser: arg_parser = ArgumentParser() arg_parser.add_argument("mode", help="Mode to run script", choices=["train", "test"]) arg_parser.add_argument("-c", "--config", help="Path to YAML configuration file", type=str) - arg_parser.add_argument("-p", "--pretrained", help="Path to pretrained model", type=str, required=False, - default=None) + arg_parser.add_argument( + "-p", "--pretrained", help="Path to pretrained model", type=str, required=False, default=None + ) return arg_parser @@ -54,6 +55,7 @@ def test_code2seq(model_path: str, config: DictConfig): __args = __arg_parser.parse_args() __config = cast(DictConfig, OmegaConf.load(__args.config)) + seed_everything(__config.seed) if __args.mode == "train": train_code2seq(__config) else: diff --git a/code2seq/model/code2seq.py b/code2seq/model/code2seq.py index c63b9c2..1ac3963 100644 --- a/code2seq/model/code2seq.py +++ b/code2seq/model/code2seq.py @@ -21,11 +21,11 @@ class Code2Seq(LightningModule): def __init__( - self, - model_config: DictConfig, - optimizer_config: DictConfig, - vocabulary: Vocabulary, - teacher_forcing: float = 0.0, + self, + model_config: DictConfig, + optimizer_config: DictConfig, + vocabulary: Vocabulary, + teacher_forcing: float = 0.0, ): super().__init__() self.save_hyperparameters() @@ -75,13 +75,13 @@ def configure_optimizers(self) -> Tuple[List[Optimizer], List[_LRScheduler]]: return configure_optimizers_alon(self._optim_config, self.parameters()) def forward( # type: ignore - self, - from_token: torch.Tensor, - path_nodes: torch.Tensor, - to_token: torch.Tensor, - contexts_per_label: torch.Tensor, - output_length: int, - target_sequence: torch.Tensor = None, + self, + from_token: torch.Tensor, + path_nodes: torch.Tensor, + to_token: torch.Tensor, + contexts_per_label: torch.Tensor, + output_length: int, + target_sequence: torch.Tensor = None, ) -> Tuple[torch.Tensor, torch.Tensor]: encoded_paths = self._encoder(from_token, path_nodes, to_token) output_logits, attention_weights = self._decoder( @@ -92,7 +92,7 @@ def forward( # type: ignore # ========== Model step ========== def logits_from_batch( 
- self, batch: BatchedLabeledPathContext, target_sequence: Optional[torch.Tensor] = None + self, batch: BatchedLabeledPathContext, target_sequence: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: return self( batch.from_token, diff --git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py index d713b92..aada369 100644 --- a/code2seq/model/comment_code2seq.py +++ b/code2seq/model/comment_code2seq.py @@ -15,11 +15,11 @@ class CommentCode2Seq(Code2Seq): def __init__( - self, - model_config: DictConfig, - optimizer_config: DictConfig, - vocabulary: Vocabulary, - teacher_forcing: float = 0.0, + self, + model_config: DictConfig, + optimizer_config: DictConfig, + vocabulary: Vocabulary, + teacher_forcing: float = 0.0, ): super().__init__(model_config, optimizer_config, vocabulary, teacher_forcing) @@ -34,9 +34,7 @@ def __init__( } # TODO add concatenation and rouge-L metric - metrics.update( - {f"{holdout}_chrf": CommentChrF(tokenizer) for holdout in ["val", "test"]} - ) + metrics.update({f"{holdout}_chrf": CommentChrF(tokenizer) for holdout in ["val", "test"]}) self._metrics = MetricCollection(metrics) self._encoder = self._get_encoder(model_config) From 24c8484e6c768dacab8f5344aba0c80b40a660d6 Mon Sep 17 00:00:00 2001 From: malodetz Date: Wed, 20 Jul 2022 13:07:29 +0300 Subject: [PATCH 12/33] Fixing chrf and f1 --- code2seq/data/comment_path_context_dataset.py | 5 ++--- code2seq/model/code2seq.py | 14 ++++++++------ code2seq/model/comment_code2seq.py | 12 ++++++------ config/comment-code2seq-java-large.yaml | 6 +++--- 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/code2seq/data/comment_path_context_dataset.py b/code2seq/data/comment_path_context_dataset.py index ba16dac..bf1c5e0 100644 --- a/code2seq/data/comment_path_context_dataset.py +++ b/code2seq/data/comment_path_context_dataset.py @@ -4,17 +4,16 @@ from code2seq.data.path_context_dataset import PathContextDataset +tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base") + class CommentPathContextDataset(PathContextDataset): @staticmethod def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]: - tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base") - label_with_spaces = " ".join(raw_label.split(PathContextDataset._separator)) label_tokens = tokenizer.tokenize(label_with_spaces) if max_parts is None: max_parts = len(label_tokens) label_tokens = [tokenizer.bos_token] + label_tokens[: max_parts - 2] + [tokenizer.eos_token] label_tokens += [tokenizer.pad_token] * (max_parts - len(label_tokens)) - return tokenizer.convert_tokens_to_ids(label_tokens) diff --git a/code2seq/model/code2seq.py b/code2seq/model/code2seq.py index 1ac3963..aa644a1 100644 --- a/code2seq/model/code2seq.py +++ b/code2seq/model/code2seq.py @@ -35,26 +35,26 @@ def __init__( if vocabulary.SOS not in vocabulary.label_to_id: raise ValueError(f"Can't find SOS token in label to id vocabulary") - self.__pad_idx = vocabulary.label_to_id[vocabulary.PAD] + self._pad_idx = vocabulary.label_to_id[vocabulary.PAD] eos_idx = vocabulary.label_to_id[vocabulary.EOS] ignore_idx = [vocabulary.label_to_id[vocabulary.SOS], vocabulary.label_to_id[vocabulary.UNK]] metrics: Dict[str, Metric] = { - f"{holdout}_f1": SequentialF1Score(pad_idx=self.__pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx) + f"{holdout}_f1": SequentialF1Score(pad_idx=self._pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx) for holdout in ["train", "val", "test"] } id2label = {v: k for 
k, v in vocabulary.label_to_id.items()} metrics.update( - {f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self.__pad_idx, eos_idx]) for holdout in ["val", "test"]} + {f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self._pad_idx, eos_idx]) for holdout in ["val", "test"]} ) self._metrics = MetricCollection(metrics) self._encoder = self._get_encoder(model_config) - decoder_step = LSTMDecoderStep(model_config, len(vocabulary.label_to_id), self.__pad_idx) + decoder_step = LSTMDecoderStep(model_config, len(vocabulary.label_to_id), self._pad_idx) self._decoder = Decoder( decoder_step, len(vocabulary.label_to_id), vocabulary.label_to_id[vocabulary.SOS], teacher_forcing ) - self.__loss = SequenceCrossEntropyLoss(self.__pad_idx, reduction="batch-mean") + self._loss = SequenceCrossEntropyLoss(self._pad_idx, reduction="batch-mean") @property def vocabulary(self) -> Vocabulary: @@ -107,7 +107,9 @@ def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict: target_sequence = batch.labels if step == "train" else None # [seq length; batch size; vocab size] logits, _ = self.logits_from_batch(batch, target_sequence) - result = {f"{step}/loss": self.__loss(logits[1:], batch.labels[1:])} + logits = logits[1:] + batch.labels = batch.labels[1:] + result = {f"{step}/loss": self._loss(logits, batch.labels)} with torch.no_grad(): prediction = logits.argmax(-1) diff --git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py index aada369..802efcd 100644 --- a/code2seq/model/comment_code2seq.py +++ b/code2seq/model/comment_code2seq.py @@ -25,11 +25,11 @@ def __init__( tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base") - self.__pad_idx = tokenizer.pad_token_id + self._pad_idx = tokenizer.pad_token_id eos_idx = tokenizer.eos_token_id ignore_idx = [tokenizer.bos_token_id, tokenizer.unk_token_id] metrics: Dict[str, Metric] = { - f"{holdout}_f1": SequentialF1Score(pad_idx=self.__pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx) + f"{holdout}_f1": SequentialF1Score(pad_idx=self._pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx) for holdout in ["train", "val", "test"] } @@ -39,10 +39,10 @@ def __init__( self._encoder = self._get_encoder(model_config) output_size = len(tokenizer.get_vocab()) - decoder_step = LSTMDecoderStep(model_config, output_size, self.__pad_idx) + decoder_step = LSTMDecoderStep(model_config, output_size, self._pad_idx) self._decoder = Decoder(decoder_step, output_size, tokenizer.eos_token_id, teacher_forcing) - self.__loss = SequenceCrossEntropyLoss(self.__pad_idx, reduction="batch-mean") + self._loss = SequenceCrossEntropyLoss(self._pad_idx, reduction="batch-mean") class CommentChrF(Metric): @@ -70,8 +70,8 @@ def update(self, predicted: torch.Tensor, target: torch.Tensor): target_seq = [token.item() for token in target[:, batch_idx]] predicted_seq = [token.item() for token in predicted[:, batch_idx]] - target_str = " ".join(self.__tokenizer.convert_ids_to_tokens(target_seq, skip_special_tokens=True)) - predicted_str = " ".join(self.__tokenizer.convert_ids_to_tokens(predicted_seq, skip_special_tokens=True)) + target_str = self.__tokenizer.decode(target_seq, skip_special_tokens=True) + predicted_str = self.__tokenizer.decode(predicted_seq, skip_special_tokens=True) if target_str == "": # Empty target string mean that the original string encoded only with token diff --git a/config/comment-code2seq-java-large.yaml b/config/comment-code2seq-java-large.yaml index 4169c2c..1342a1e 100644 --- a/config/comment-code2seq-java-large.yaml +++ 
b/config/comment-code2seq-java-large.yaml @@ -16,7 +16,7 @@ data: num_workers: 4 # Each token appears at least 10 times (99.2% coverage) - labels_count: 10 + labels_count: 1 max_label_parts: 256 # Each token appears at least 1000 times (99.5% coverage) tokens_count: 1 @@ -26,8 +26,8 @@ data: max_context: 200 random_context: true - batch_size: 8 - test_batch_size: 8 + batch_size: 16 + test_batch_size: 16 model: # Encoder From 33db6b547e1f758321a798707ca964e27758f36f Mon Sep 17 00:00:00 2001 From: malodetz Date: Thu, 28 Jul 2022 11:54:44 +0300 Subject: [PATCH 13/33] Preliminary new tokenizer --- .../data/comment_path_context_data_module.py | 61 +++++++++++++++++++ code2seq/data/vocabulary.py | 16 +++++ 2 files changed, 77 insertions(+) diff --git a/code2seq/data/comment_path_context_data_module.py b/code2seq/data/comment_path_context_data_module.py index 2e40dc1..cfc9be7 100644 --- a/code2seq/data/comment_path_context_data_module.py +++ b/code2seq/data/comment_path_context_data_module.py @@ -1,9 +1,70 @@ +import pickle +from collections import Counter +from os.path import join, exists, dirname +from typing import Dict, Counter as TCounter, Type + +from commode_utils.vocabulary import BaseVocabulary +from tqdm.auto import tqdm + +from commode_utils.filesystem import count_lines_in_file +from omegaconf import DictConfig +from transformers import RobertaTokenizerFast + from code2seq.data.comment_path_context_dataset import CommentPathContextDataset from code2seq.data.path_context_data_module import PathContextDataModule +from code2seq.data.vocabulary import CommentVocabulary + + +def _build_from_scratch(train_data: str, labels_count: int, vocabulary_cls: Type[BaseVocabulary]): + total_samples = count_lines_in_file(train_data) + counters: Dict[str, TCounter[str]] = { + key: Counter() for key in [vocabulary_cls.LABEL, vocabulary_cls.TOKEN, vocabulary_cls.NODE] + } + with open(train_data, "r") as f_in: + for raw_sample in tqdm(f_in, total=total_samples): + vocabulary_cls.process_raw_sample(raw_sample, counters) + + training_corpus = [] + good_labels_count = 0 + for string, amount in counters[vocabulary_cls.LABEL].items(): + if amount >= labels_count: + training_corpus.extend([string] * amount) + good_labels_count += 1 + old_tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base") + tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, 4 * good_labels_count) + + for feature, counter in counters.items(): + print(f"Count {len(counter)} {feature}, top-5: {counter.most_common(5)}") + + dataset_dir = dirname(train_data) + vocabulary_file = join(dataset_dir, vocabulary_cls.vocab_filename) + with open(vocabulary_file, "wb") as f_out: + pickle.dump(counters, f_out) + pickle.dump(tokenizer, f_out) class CommentPathContextDataModule(PathContextDataModule): + _vocabulary: CommentVocabulary + + def __init__(self, data_dir: str, config: DictConfig): + super().__init__(data_dir, config) + def _create_dataset(self, holdout_file: str, random_context: bool) -> CommentPathContextDataset: if self._vocabulary is None: raise RuntimeError(f"Setup vocabulary before creating data loaders") return CommentPathContextDataset(holdout_file, self._config, self._vocabulary, random_context) + + def setup_vocabulary(self) -> CommentVocabulary: + if not exists(join(self._data_dir, CommentVocabulary.vocab_filename)): + print("Can't find vocabulary, collect it from train holdout") + _build_from_scratch( + join(self._data_dir, f"{self._train}.c2s"), self._config.labels_count, CommentVocabulary + ) + 
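+        # _build_from_scratch pickles the feature counters together with the newly
+        # trained tokenizer next to the train holdout; load them back via CommentVocabulary.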
vocabulary_path = join(self._data_dir, CommentVocabulary.vocab_filename) + return CommentVocabulary(vocabulary_path, self._config.labels_count, self._config.tokens_count) + + @property + def vocabulary(self) -> CommentVocabulary: + if self._vocabulary is None: + raise RuntimeError(f"Setup data module for initializing vocabulary") + return self._vocabulary diff --git a/code2seq/data/vocabulary.py b/code2seq/data/vocabulary.py index f0a575a..95ae4d8 100644 --- a/code2seq/data/vocabulary.py +++ b/code2seq/data/vocabulary.py @@ -1,3 +1,4 @@ +import pickle from argparse import ArgumentParser from collections import Counter from os.path import dirname, join @@ -5,6 +6,7 @@ from typing import Dict, Counter as CounterType, Optional, List from commode_utils.vocabulary import BaseVocabulary, build_from_scratch +from transformers import PreTrainedTokenizerFast class Vocabulary(BaseVocabulary): @@ -71,6 +73,20 @@ def process_raw_sample(raw_sample: str, counters: Dict[str, CounterType[str]]): TypedVocabulary._process_raw_sample(raw_sample, counters, context_seq) +class CommentVocabulary(Vocabulary): + def __init__( + self, + vocabulary_file: str, + labels_count: Optional[int] = None, + tokens_count: Optional[int] = None, + ): + super().__init__(vocabulary_file, labels_count, tokens_count) + with open(vocabulary_file, "rb") as f_in: + pickle.load(f_in) + self.tokenizer: PreTrainedTokenizerFast = pickle.load(f_in) + print(len(self.tokenizer.vocab)) + + def convert_from_vanilla(vocabulary_path: str): counters: Dict[str, CounterType[str]] = {} with open(vocabulary_path, "rb") as dict_file: From dc15cdba012a1a455506152465121401d50f705d Mon Sep 17 00:00:00 2001 From: malodetz Date: Thu, 28 Jul 2022 12:07:38 +0300 Subject: [PATCH 14/33] Complete new tokenizer --- code2seq/data/comment_path_context_dataset.py | 14 +++++++++----- code2seq/data/vocabulary.py | 1 - code2seq/model/comment_code2seq.py | 7 +++---- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/code2seq/data/comment_path_context_dataset.py b/code2seq/data/comment_path_context_dataset.py index bf1c5e0..a3e351f 100644 --- a/code2seq/data/comment_path_context_dataset.py +++ b/code2seq/data/comment_path_context_dataset.py @@ -1,19 +1,23 @@ from typing import Dict, List, Optional -from transformers import RobertaTokenizerFast +from code2seq.data.vocabulary import CommentVocabulary +from omegaconf import DictConfig from code2seq.data.path_context_dataset import PathContextDataset -tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base") - class CommentPathContextDataset(PathContextDataset): - @staticmethod - def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]: + def __init__(self, data_file: str, config: DictConfig, vocabulary: CommentVocabulary, random_context: bool): + super().__init__(data_file, config, vocabulary, random_context) + self._vocab: CommentVocabulary = vocabulary + + def tokenize_label(self, raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]: label_with_spaces = " ".join(raw_label.split(PathContextDataset._separator)) + tokenizer = self._vocab.tokenizer label_tokens = tokenizer.tokenize(label_with_spaces) if max_parts is None: max_parts = len(label_tokens) label_tokens = [tokenizer.bos_token] + label_tokens[: max_parts - 2] + [tokenizer.eos_token] label_tokens += [tokenizer.pad_token] * (max_parts - len(label_tokens)) + print(label_tokens) return tokenizer.convert_tokens_to_ids(label_tokens) diff --git a/code2seq/data/vocabulary.py 
b/code2seq/data/vocabulary.py index 95ae4d8..9e27a52 100644 --- a/code2seq/data/vocabulary.py +++ b/code2seq/data/vocabulary.py @@ -84,7 +84,6 @@ def __init__( with open(vocabulary_file, "rb") as f_in: pickle.load(f_in) self.tokenizer: PreTrainedTokenizerFast = pickle.load(f_in) - print(len(self.tokenizer.vocab)) def convert_from_vanilla(vocabulary_path: str): diff --git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py index 802efcd..881b36c 100644 --- a/code2seq/model/comment_code2seq.py +++ b/code2seq/model/comment_code2seq.py @@ -9,7 +9,7 @@ from torchmetrics import MetricCollection, Metric from transformers import RobertaTokenizerFast -from code2seq.data.vocabulary import Vocabulary +from code2seq.data.vocabulary import Vocabulary, CommentVocabulary from code2seq.model import Code2Seq @@ -18,13 +18,12 @@ def __init__( self, model_config: DictConfig, optimizer_config: DictConfig, - vocabulary: Vocabulary, + vocabulary: CommentVocabulary, teacher_forcing: float = 0.0, ): super().__init__(model_config, optimizer_config, vocabulary, teacher_forcing) - tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base") - + tokenizer = vocabulary.tokenizer self._pad_idx = tokenizer.pad_token_id eos_idx = tokenizer.eos_token_id ignore_idx = [tokenizer.bos_token_id, tokenizer.unk_token_id] From 26dd42f256d80785755ffba1bdd86bd6d0a75068 Mon Sep 17 00:00:00 2001 From: malodetz Date: Thu, 28 Jul 2022 12:36:29 +0300 Subject: [PATCH 15/33] Some fixes --- code2seq/data/comment_path_context_data_module.py | 2 +- code2seq/data/comment_path_context_dataset.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/code2seq/data/comment_path_context_data_module.py b/code2seq/data/comment_path_context_data_module.py index cfc9be7..b73c7db 100644 --- a/code2seq/data/comment_path_context_data_module.py +++ b/code2seq/data/comment_path_context_data_module.py @@ -31,7 +31,7 @@ def _build_from_scratch(train_data: str, labels_count: int, vocabulary_cls: Type training_corpus.extend([string] * amount) good_labels_count += 1 old_tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base") - tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, 4 * good_labels_count) + tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, 128 * good_labels_count) for feature, counter in counters.items(): print(f"Count {len(counter)} {feature}, top-5: {counter.most_common(5)}") diff --git a/code2seq/data/comment_path_context_dataset.py b/code2seq/data/comment_path_context_dataset.py index a3e351f..c0a8e1c 100644 --- a/code2seq/data/comment_path_context_dataset.py +++ b/code2seq/data/comment_path_context_dataset.py @@ -19,5 +19,4 @@ def tokenize_label(self, raw_label: str, vocab: Dict[str, int], max_parts: Optio max_parts = len(label_tokens) label_tokens = [tokenizer.bos_token] + label_tokens[: max_parts - 2] + [tokenizer.eos_token] label_tokens += [tokenizer.pad_token] * (max_parts - len(label_tokens)) - print(label_tokens) return tokenizer.convert_tokens_to_ids(label_tokens) From f6fa424e865ee491e59e7d9dac10d1159e830928 Mon Sep 17 00:00:00 2001 From: malodetz Date: Thu, 28 Jul 2022 13:23:35 +0300 Subject: [PATCH 16/33] New vocab size --- code2seq/data/comment_path_context_data_module.py | 13 ++++--------- code2seq/model/comment_code2seq.py | 4 ++-- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/code2seq/data/comment_path_context_data_module.py b/code2seq/data/comment_path_context_data_module.py index b73c7db..514c979 100644 --- 
a/code2seq/data/comment_path_context_data_module.py +++ b/code2seq/data/comment_path_context_data_module.py @@ -15,7 +15,7 @@ from code2seq.data.vocabulary import CommentVocabulary -def _build_from_scratch(train_data: str, labels_count: int, vocabulary_cls: Type[BaseVocabulary]): +def _build_from_scratch(train_data: str, vocabulary_cls: Type[BaseVocabulary]): total_samples = count_lines_in_file(train_data) counters: Dict[str, TCounter[str]] = { key: Counter() for key in [vocabulary_cls.LABEL, vocabulary_cls.TOKEN, vocabulary_cls.NODE] @@ -25,13 +25,10 @@ def _build_from_scratch(train_data: str, labels_count: int, vocabulary_cls: Type vocabulary_cls.process_raw_sample(raw_sample, counters) training_corpus = [] - good_labels_count = 0 for string, amount in counters[vocabulary_cls.LABEL].items(): - if amount >= labels_count: - training_corpus.extend([string] * amount) - good_labels_count += 1 + training_corpus.extend([string] * amount) old_tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base") - tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, 128 * good_labels_count) + tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, 20000) for feature, counter in counters.items(): print(f"Count {len(counter)} {feature}, top-5: {counter.most_common(5)}") @@ -57,9 +54,7 @@ def _create_dataset(self, holdout_file: str, random_context: bool) -> CommentPat def setup_vocabulary(self) -> CommentVocabulary: if not exists(join(self._data_dir, CommentVocabulary.vocab_filename)): print("Can't find vocabulary, collect it from train holdout") - _build_from_scratch( - join(self._data_dir, f"{self._train}.c2s"), self._config.labels_count, CommentVocabulary - ) + _build_from_scratch(join(self._data_dir, f"{self._train}.c2s"), CommentVocabulary) vocabulary_path = join(self._data_dir, CommentVocabulary.vocab_filename) return CommentVocabulary(vocabulary_path, self._config.labels_count, self._config.tokens_count) diff --git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py index 881b36c..d4728b8 100644 --- a/code2seq/model/comment_code2seq.py +++ b/code2seq/model/comment_code2seq.py @@ -9,7 +9,7 @@ from torchmetrics import MetricCollection, Metric from transformers import RobertaTokenizerFast -from code2seq.data.vocabulary import Vocabulary, CommentVocabulary +from code2seq.data.vocabulary import CommentVocabulary from code2seq.model import Code2Seq @@ -37,7 +37,7 @@ def __init__( self._metrics = MetricCollection(metrics) self._encoder = self._get_encoder(model_config) - output_size = len(tokenizer.get_vocab()) + output_size = tokenizer.vocab_size decoder_step = LSTMDecoderStep(model_config, output_size, self._pad_idx) self._decoder = Decoder(decoder_step, output_size, tokenizer.eos_token_id, teacher_forcing) From e8b677d4e8b08014cfe2c12a9eb05bb6f1b4d88d Mon Sep 17 00:00:00 2001 From: malodetz Date: Sun, 7 Aug 2022 20:35:23 +0300 Subject: [PATCH 17/33] Add tokenizer to config --- code2seq/data/comment_path_context_data_module.py | 11 +++++++---- config/comment-code2seq-java-large.yaml | 5 ++++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/code2seq/data/comment_path_context_data_module.py b/code2seq/data/comment_path_context_data_module.py index 514c979..0db718a 100644 --- a/code2seq/data/comment_path_context_data_module.py +++ b/code2seq/data/comment_path_context_data_module.py @@ -15,7 +15,7 @@ from code2seq.data.vocabulary import CommentVocabulary -def _build_from_scratch(train_data: str, vocabulary_cls: Type[BaseVocabulary]): 
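     # Collect label/token/node frequency counters from the train holdout; the label
     # strings also serve as the corpus for optionally training a new tokenizer below.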
+def _build_from_scratch(config: DictConfig, train_data: str, vocabulary_cls: Type[BaseVocabulary]): total_samples = count_lines_in_file(train_data) counters: Dict[str, TCounter[str]] = { key: Counter() for key in [vocabulary_cls.LABEL, vocabulary_cls.TOKEN, vocabulary_cls.NODE] @@ -27,8 +27,11 @@ def _build_from_scratch(train_data: str, vocabulary_cls: Type[BaseVocabulary]): training_corpus = [] for string, amount in counters[vocabulary_cls.LABEL].items(): training_corpus.extend([string] * amount) - old_tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base") - tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, 20000) + old_tokenizer = RobertaTokenizerFast.from_pretrained(config.base_tokenizer) + if config.train_new_tokenizer: + tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, 20000) + else: + tokenizer = old_tokenizer for feature, counter in counters.items(): print(f"Count {len(counter)} {feature}, top-5: {counter.most_common(5)}") @@ -54,7 +57,7 @@ def _create_dataset(self, holdout_file: str, random_context: bool) -> CommentPat def setup_vocabulary(self) -> CommentVocabulary: if not exists(join(self._data_dir, CommentVocabulary.vocab_filename)): print("Can't find vocabulary, collect it from train holdout") - _build_from_scratch(join(self._data_dir, f"{self._train}.c2s"), CommentVocabulary) + _build_from_scratch(self._config, join(self._data_dir, f"{self._train}.c2s"), CommentVocabulary) vocabulary_path = join(self._data_dir, CommentVocabulary.vocab_filename) return CommentVocabulary(vocabulary_path, self._config.labels_count, self._config.tokens_count) diff --git a/config/comment-code2seq-java-large.yaml b/config/comment-code2seq-java-large.yaml index 1342a1e..1e3917e 100644 --- a/config/comment-code2seq-java-large.yaml +++ b/config/comment-code2seq-java-large.yaml @@ -8,13 +8,16 @@ progress_bar_refresh_rate: 1 print_config: true wandb: - project: Code2Seq -- java-med + project: comment-code2seq group: null offline: false data: num_workers: 4 + base_tokenizer: "microsoft/codebert-base" + train_new_tokenizer: true + # Each token appears at least 10 times (99.2% coverage) labels_count: 1 max_label_parts: 256 From b1b2e27d44397a40b0dc718d34114afa1653da42 Mon Sep 17 00:00:00 2001 From: malodetz Date: Mon, 8 Aug 2022 17:36:23 +0300 Subject: [PATCH 18/33] Move chrf metric --- code2seq/model/comment_code2seq.py | 43 +----------------------------- code2seq/model/modules/metrics.py | 43 ++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 42 deletions(-) create mode 100644 code2seq/model/modules/metrics.py diff --git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py index d4728b8..877b37d 100644 --- a/code2seq/model/comment_code2seq.py +++ b/code2seq/model/comment_code2seq.py @@ -1,16 +1,14 @@ from typing import Dict -import torch from commode_utils.losses import SequenceCrossEntropyLoss from commode_utils.metrics import SequentialF1Score from commode_utils.modules import LSTMDecoderStep, Decoder from omegaconf import DictConfig -from sacrebleu import CHRF from torchmetrics import MetricCollection, Metric -from transformers import RobertaTokenizerFast from code2seq.data.vocabulary import CommentVocabulary from code2seq.model import Code2Seq +from code2seq.model.modules.metrics import CommentChrF class CommentCode2Seq(Code2Seq): @@ -42,42 +40,3 @@ def __init__( self._decoder = Decoder(decoder_step, output_size, tokenizer.eos_token_id, teacher_forcing) self._loss = SequenceCrossEntropyLoss(self._pad_idx, 
reduction="batch-mean") - - -class CommentChrF(Metric): - def __init__(self, tokenizer: RobertaTokenizerFast, **kwargs): - super().__init__(**kwargs) - self.__tokenizer = tokenizer - self.__chrf = CHRF() - - # Metric states - self.add_state("chrf", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum") - self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") - - def update(self, predicted: torch.Tensor, target: torch.Tensor): - """Calculated ChrF metric on predicted tensor w.r.t. target tensor. - - :param predicted: [pred seq len; batch size] -- tensor with predicted tokens - :param target: [target seq len; batch size] -- tensor with ground truth tokens - :return: - """ - batch_size = target.shape[1] - if predicted.shape[1] != batch_size: - raise ValueError(f"Wrong batch size for prediction (expected: {batch_size}, actual: {predicted.shape[1]})") - - for batch_idx in range(batch_size): - target_seq = [token.item() for token in target[:, batch_idx]] - predicted_seq = [token.item() for token in predicted[:, batch_idx]] - - target_str = self.__tokenizer.decode(target_seq, skip_special_tokens=True) - predicted_str = self.__tokenizer.decode(predicted_seq, skip_special_tokens=True) - - if target_str == "": - # Empty target string mean that the original string encoded only with token - continue - - self.chrf += self.__chrf.sentence_score(predicted_str, [target_str]).score - self.count += 1 - - def compute(self) -> torch.Tensor: - return self.chrf / self.count diff --git a/code2seq/model/modules/metrics.py b/code2seq/model/modules/metrics.py new file mode 100644 index 0000000..17cc7e2 --- /dev/null +++ b/code2seq/model/modules/metrics.py @@ -0,0 +1,43 @@ +import torch +from torchmetrics import Metric +from sacrebleu import CHRF +from transformers import RobertaTokenizerFast + + +class CommentChrF(Metric): + def __init__(self, tokenizer: RobertaTokenizerFast, **kwargs): + super().__init__(**kwargs) + self.__tokenizer = tokenizer + self.__chrf = CHRF() + + # Metric states + self.add_state("chrf", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum") + self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, predicted: torch.Tensor, target: torch.Tensor): + """Calculated ChrF metric on predicted tensor w.r.t. target tensor. 
+ + :param predicted: [pred seq len; batch size] -- tensor with predicted tokens + :param target: [target seq len; batch size] -- tensor with ground truth tokens + :return: + """ + batch_size = target.shape[1] + if predicted.shape[1] != batch_size: + raise ValueError(f"Wrong batch size for prediction (expected: {batch_size}, actual: {predicted.shape[1]})") + + for batch_idx in range(batch_size): + target_seq = [token.item() for token in target[:, batch_idx]] + predicted_seq = [token.item() for token in predicted[:, batch_idx]] + + target_str = self.__tokenizer.decode(target_seq, skip_special_tokens=True) + predicted_str = self.__tokenizer.decode(predicted_seq, skip_special_tokens=True) + + if target_str == "": + # Empty target string mean that the original string encoded only with token + continue + + self.chrf += self.__chrf.sentence_score(predicted_str, [target_str]).score + self.count += 1 + + def compute(self) -> torch.Tensor: + return self.chrf / self.count From e70f96d5474a81279f1773e2c1d48abcd9029a5b Mon Sep 17 00:00:00 2001 From: malodetz Date: Mon, 8 Aug 2022 18:05:13 +0300 Subject: [PATCH 19/33] Better tokenization --- code2seq/data/comment_path_context_data_module.py | 2 +- code2seq/data/comment_path_context_dataset.py | 14 +++++++------- config/comment-code2seq-java-large.yaml | 1 + 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/code2seq/data/comment_path_context_data_module.py b/code2seq/data/comment_path_context_data_module.py index 0db718a..6de9f84 100644 --- a/code2seq/data/comment_path_context_data_module.py +++ b/code2seq/data/comment_path_context_data_module.py @@ -29,7 +29,7 @@ def _build_from_scratch(config: DictConfig, train_data: str, vocabulary_cls: Typ training_corpus.extend([string] * amount) old_tokenizer = RobertaTokenizerFast.from_pretrained(config.base_tokenizer) if config.train_new_tokenizer: - tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, 20000) + tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, config.max_tokenizer_vocab) else: tokenizer = old_tokenizer diff --git a/code2seq/data/comment_path_context_dataset.py b/code2seq/data/comment_path_context_dataset.py index c0a8e1c..11b57e2 100644 --- a/code2seq/data/comment_path_context_dataset.py +++ b/code2seq/data/comment_path_context_dataset.py @@ -12,11 +12,11 @@ def __init__(self, data_file: str, config: DictConfig, vocabulary: CommentVocabu self._vocab: CommentVocabulary = vocabulary def tokenize_label(self, raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]: - label_with_spaces = " ".join(raw_label.split(PathContextDataset._separator)) tokenizer = self._vocab.tokenizer - label_tokens = tokenizer.tokenize(label_with_spaces) - if max_parts is None: - max_parts = len(label_tokens) - label_tokens = [tokenizer.bos_token] + label_tokens[: max_parts - 2] + [tokenizer.eos_token] - label_tokens += [tokenizer.pad_token] * (max_parts - len(label_tokens)) - return tokenizer.convert_tokens_to_ids(label_tokens) + tokenized_snippet = tokenizer( + raw_label.replace(PathContextDataset._separator, " "), + add_special_tokens=True, + padding="max_length" if max_parts else "do_not_pad", + max_length=max_parts, + ) + return tokenized_snippet["input_ids"] diff --git a/config/comment-code2seq-java-large.yaml b/config/comment-code2seq-java-large.yaml index 1e3917e..c43ca24 100644 --- a/config/comment-code2seq-java-large.yaml +++ b/config/comment-code2seq-java-large.yaml @@ -17,6 +17,7 @@ data: base_tokenizer: "microsoft/codebert-base" 
train_new_tokenizer: true + max_tokenizer_vocab: 20000 # Each token appears at least 10 times (99.2% coverage) labels_count: 1 From 5e93981ec335afdfcf271cff585973222cad4865 Mon Sep 17 00:00:00 2001 From: malodetz Date: Wed, 10 Aug 2022 15:22:51 +0300 Subject: [PATCH 20/33] Implement comment transformer decoder --- code2seq/model/comment_code2seq.py | 23 ++++++--- code2seq/model/modules/comment_decoder.py | 57 +++++++++++++++++++++++ config/comment-code2seq-java-large.yaml | 10 ++-- 3 files changed, 79 insertions(+), 11 deletions(-) create mode 100644 code2seq/model/modules/comment_decoder.py diff --git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py index 877b37d..723bbb5 100644 --- a/code2seq/model/comment_code2seq.py +++ b/code2seq/model/comment_code2seq.py @@ -2,12 +2,12 @@ from commode_utils.losses import SequenceCrossEntropyLoss from commode_utils.metrics import SequentialF1Score -from commode_utils.modules import LSTMDecoderStep, Decoder from omegaconf import DictConfig from torchmetrics import MetricCollection, Metric from code2seq.data.vocabulary import CommentVocabulary from code2seq.model import Code2Seq +from code2seq.model.modules.comment_decoder import CommentDecoder from code2seq.model.modules.metrics import CommentChrF @@ -19,14 +19,19 @@ def __init__( vocabulary: CommentVocabulary, teacher_forcing: float = 0.0, ): - super().__init__(model_config, optimizer_config, vocabulary, teacher_forcing) + super(Code2Seq, self).__init__() + + self.save_hyperparameters() + self._optim_config = optimizer_config + self._vocabulary = vocabulary tokenizer = vocabulary.tokenizer self._pad_idx = tokenizer.pad_token_id - eos_idx = tokenizer.eos_token_id + self._eos_idx = tokenizer.eos_token_id + self._sos_idx = tokenizer.bos_token_id ignore_idx = [tokenizer.bos_token_id, tokenizer.unk_token_id] metrics: Dict[str, Metric] = { - f"{holdout}_f1": SequentialF1Score(pad_idx=self._pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx) + f"{holdout}_f1": SequentialF1Score(pad_idx=self._pad_idx, eos_idx=self._eos_idx, ignore_idx=ignore_idx) for holdout in ["train", "val", "test"] } @@ -35,8 +40,12 @@ def __init__( self._metrics = MetricCollection(metrics) self._encoder = self._get_encoder(model_config) - output_size = tokenizer.vocab_size - decoder_step = LSTMDecoderStep(model_config, output_size, self._pad_idx) - self._decoder = Decoder(decoder_step, output_size, tokenizer.eos_token_id, teacher_forcing) + self._decoder = CommentDecoder( + model_config, + vocab_size=tokenizer.vocab_size, + pad_token=self._pad_idx, + sos_token=self._sos_idx, + teacher_forcing=teacher_forcing, + ) self._loss = SequenceCrossEntropyLoss(self._pad_idx, reduction="batch-mean") diff --git a/code2seq/model/modules/comment_decoder.py b/code2seq/model/modules/comment_decoder.py new file mode 100644 index 0000000..4575ed0 --- /dev/null +++ b/code2seq/model/modules/comment_decoder.py @@ -0,0 +1,57 @@ +from commode_utils.training import cut_into_segments +from omegaconf import DictConfig +from torch import nn, Tensor, LongTensor +from torch.nn import TransformerDecoder, Embedding, Linear +from torch.nn.modules.transformer import TransformerDecoderLayer, Transformer +from typing import Tuple + + +class CommentDecoder(nn.Module): + def __init__( + self, config: DictConfig, vocab_size: int, pad_token: int, sos_token: int, teacher_forcing: float = 0.0 + ): + super().__init__() + self._pad_token = pad_token + self._sos_token = sos_token + self._teacher_forcing = teacher_forcing + + self._embedding = 
Embedding(vocab_size, config.decoder_size, padding_idx=pad_token) + decoder_layer = TransformerDecoderLayer( + d_model=config.decoder_size, + nhead=config.decoder_num_heads, + dim_feedforward=config.decoder_dim_feedforward, + dropout=config.decoder_dropout, + batch_first=True, + ) + self._decoder = TransformerDecoder(decoder_layer, config.decoder_num_layers) + self._linear = Linear(config.decoder_size, vocab_size) + + def forward( + self, + encoder_output: Tensor, + segment_sizes: LongTensor, + output_size: int, + target_sequence: Tensor = None, + ) -> Tuple[Tensor, Tensor]: + + batch_size = segment_sizes.shape[0] + + if not self.training: + target_sequence = encoder_output.new_zeros((batch_size, output_size), dtype=int) + else: + target_sequence = target_sequence.permute(1, 0) + + batched_encoder_output, attention_mask = cut_into_segments(encoder_output, segment_sizes) + attentions = batched_encoder_output.new_zeros((output_size, batch_size, attention_mask.shape[1])) + + embedded = self._embedding(target_sequence) + + tgt_mask = Transformer.generate_square_subsequent_mask(output_size).to(target_sequence.get_device()) + tgt_key_padding_mask = target_sequence == self._pad_token + + decoded = self._decoder( + tgt=embedded, memory=batched_encoder_output, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask + ) + + output = self._linear(decoded).permute(1, 0, 2) + return output, attentions diff --git a/config/comment-code2seq-java-large.yaml b/config/comment-code2seq-java-large.yaml index c43ca24..8f88cbe 100644 --- a/config/comment-code2seq-java-large.yaml +++ b/config/comment-code2seq-java-large.yaml @@ -10,7 +10,7 @@ print_config: true wandb: project: comment-code2seq group: null - offline: false + offline: true data: num_workers: 4 @@ -42,9 +42,11 @@ model: rnn_num_layers: 1 # Decoder - decoder_size: 320 - decoder_num_layers: 1 - rnn_dropout: 0.5 + decoder_size: 512 + decoder_num_layers: 6 + decoder_dim_feedforward: 2048 + decoder_num_heads: 8 + decoder_dropout: 0.1 optimizer: optimizer: "Momentum" From 033560cd61f8556d5da21a549ac1791b3137fadb Mon Sep 17 00:00:00 2001 From: malodetz Date: Wed, 10 Aug 2022 23:00:36 +0300 Subject: [PATCH 21/33] Greedy decoding for val/test --- code2seq/model/comment_code2seq.py | 4 +- code2seq/model/modules/comment_decoder.py | 91 +++++++++++++++++++---- config/comment-code2seq-java-large.yaml | 8 +- 3 files changed, 82 insertions(+), 21 deletions(-) diff --git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py index 723bbb5..ffd3a41 100644 --- a/code2seq/model/comment_code2seq.py +++ b/code2seq/model/comment_code2seq.py @@ -29,7 +29,7 @@ def __init__( self._pad_idx = tokenizer.pad_token_id self._eos_idx = tokenizer.eos_token_id self._sos_idx = tokenizer.bos_token_id - ignore_idx = [tokenizer.bos_token_id, tokenizer.unk_token_id] + ignore_idx = [self._sos_idx, tokenizer.unk_token_id] metrics: Dict[str, Metric] = { f"{holdout}_f1": SequentialF1Score(pad_idx=self._pad_idx, eos_idx=self._eos_idx, ignore_idx=ignore_idx) for holdout in ["train", "val", "test"] @@ -48,4 +48,4 @@ def __init__( teacher_forcing=teacher_forcing, ) - self._loss = SequenceCrossEntropyLoss(self._pad_idx, reduction="batch-mean") + self._loss = SequenceCrossEntropyLoss(self._pad_idx, reduction="seq-mean") diff --git a/code2seq/model/modules/comment_decoder.py b/code2seq/model/modules/comment_decoder.py index 4575ed0..9e24530 100644 --- a/code2seq/model/modules/comment_decoder.py +++ b/code2seq/model/modules/comment_decoder.py @@ -1,3 +1,6 @@ +import math + 
+import torch from commode_utils.training import cut_into_segments from omegaconf import DictConfig from torch import nn, Tensor, LongTensor @@ -6,16 +9,48 @@ from typing import Tuple +class PositionalEncoding(nn.Module): + def __init__(self, emb_size: int, dropout: float, max_token_length: int = 5000): + super(PositionalEncoding, self).__init__() + + den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size) + pos = torch.arange(0, max_token_length).reshape(max_token_length, 1) + + pe = torch.zeros((max_token_length, emb_size)) + pe[:, 0::2] = torch.sin(pos * den) + pe[:, 1::2] = torch.cos(pos * den) + pe = pe.unsqueeze(0) + + self.dropout = nn.Dropout(dropout) + self.register_buffer("pe", pe) + + def forward(self, token_embedding: Tensor): + output = token_embedding + self.pe[:, : token_embedding.size(1), :] + return self.dropout(output) + + +class TokenEmbedding(nn.Module): + def __init__(self, vocab_size: int, emb_size): + super(TokenEmbedding, self).__init__() + self.embedding = Embedding(vocab_size, emb_size) + self.emb_size = emb_size + + def forward(self, tokens: Tensor): + return self.embedding(tokens.long()) * math.sqrt(self.emb_size) + + class CommentDecoder(nn.Module): def __init__( self, config: DictConfig, vocab_size: int, pad_token: int, sos_token: int, teacher_forcing: float = 0.0 ): super().__init__() + self._vocab_size = vocab_size self._pad_token = pad_token self._sos_token = sos_token self._teacher_forcing = teacher_forcing - self._embedding = Embedding(vocab_size, config.decoder_size, padding_idx=pad_token) + self._embedding = TokenEmbedding(vocab_size, config.decoder_size) + self._positional_encoding = PositionalEncoding(config.decoder_size, config.decoder_dropout) decoder_layer = TransformerDecoderLayer( d_model=config.decoder_size, nhead=config.decoder_num_heads, @@ -26,6 +61,22 @@ def __init__( self._decoder = TransformerDecoder(decoder_layer, config.decoder_num_layers) self._linear = Linear(config.decoder_size, vocab_size) + def decode( + self, target_sequence: Tensor, batched_encoder_output: Tensor, tgt_mask: Tensor, attention_mask: Tensor + ) -> Tensor: + tgt_key_padding_mask = target_sequence == self._pad_token + + embedded = self._embedding(target_sequence) + positionally_encoded = self._positional_encoding(embedded) + decoded = self._decoder( + tgt=positionally_encoded, + memory=batched_encoder_output, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=attention_mask, + ) + return self._linear(decoded) + def forward( self, encoder_output: Tensor, @@ -33,25 +84,35 @@ def forward( output_size: int, target_sequence: Tensor = None, ) -> Tuple[Tensor, Tensor]: - + device = encoder_output.get_device() batch_size = segment_sizes.shape[0] - if not self.training: - target_sequence = encoder_output.new_zeros((batch_size, output_size), dtype=int) - else: - target_sequence = target_sequence.permute(1, 0) - batched_encoder_output, attention_mask = cut_into_segments(encoder_output, segment_sizes) + # TODO fill attentions with smth good attentions = batched_encoder_output.new_zeros((output_size, batch_size, attention_mask.shape[1])) - embedded = self._embedding(target_sequence) + tgt_mask = (Transformer.generate_square_subsequent_mask(output_size)).to(device) - tgt_mask = Transformer.generate_square_subsequent_mask(output_size).to(target_sequence.get_device()) - tgt_key_padding_mask = target_sequence == self._pad_token + if self.training: + target_sequence = target_sequence.permute(1, 0) - decoded = self._decoder( - 
tgt=embedded, memory=batched_encoder_output, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask - ) + output = self.decode(target_sequence, batched_encoder_output, tgt_mask, attention_mask) + else: + output = torch.zeros((batch_size, output_size, self._vocab_size)).to(device) + output[:, 0, self._sos_token] = 1 + + target_sequence = torch.zeros((batch_size, output_size)).to(device) + target_sequence[:, 1:] = self._pad_token + target_sequence[:, 0] = self._sos_token + + for i in range(1, output_size): + logits = self.decode(target_sequence, batched_encoder_output, tgt_mask, attention_mask) + + with torch.no_grad(): + prediction = logits.argmax(-1) + target_sequence[:, i] = prediction[:, i] + output[:, i, :] = logits[:, i, :] + + print(target_sequence) - output = self._linear(decoded).permute(1, 0, 2) - return output, attentions + return output.permute(1, 0, 2), attentions diff --git a/config/comment-code2seq-java-large.yaml b/config/comment-code2seq-java-large.yaml index 8f88cbe..39aa643 100644 --- a/config/comment-code2seq-java-large.yaml +++ b/config/comment-code2seq-java-large.yaml @@ -42,10 +42,10 @@ model: rnn_num_layers: 1 # Decoder - decoder_size: 512 - decoder_num_layers: 6 - decoder_dim_feedforward: 2048 - decoder_num_heads: 8 + decoder_size: 320 + decoder_num_layers: 4 + decoder_dim_feedforward: 1024 + decoder_num_heads: 4 decoder_dropout: 0.1 optimizer: From c74319f955a1ef5fb348ba8077cfcc3eb548d693 Mon Sep 17 00:00:00 2001 From: malodetz Date: Wed, 10 Aug 2022 23:01:35 +0300 Subject: [PATCH 22/33] Small fix --- code2seq/model/modules/comment_decoder.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/code2seq/model/modules/comment_decoder.py b/code2seq/model/modules/comment_decoder.py index 9e24530..7fff0f5 100644 --- a/code2seq/model/modules/comment_decoder.py +++ b/code2seq/model/modules/comment_decoder.py @@ -113,6 +113,4 @@ def forward( target_sequence[:, i] = prediction[:, i] output[:, i, :] = logits[:, i, :] - print(target_sequence) - return output.permute(1, 0, 2), attentions From b1009cd16dc913f5dbb2e7edb94f288df586e159 Mon Sep 17 00:00:00 2001 From: malodetz Date: Fri, 12 Aug 2022 14:58:58 +0300 Subject: [PATCH 23/33] Some fixes --- code2seq/model/comment_code2seq.py | 28 ++++++++++++++++++++++- code2seq/model/modules/comment_decoder.py | 19 +++++++-------- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py index ffd3a41..dd7b9be 100644 --- a/code2seq/model/comment_code2seq.py +++ b/code2seq/model/comment_code2seq.py @@ -1,10 +1,12 @@ from typing import Dict +import torch from commode_utils.losses import SequenceCrossEntropyLoss -from commode_utils.metrics import SequentialF1Score +from commode_utils.metrics import SequentialF1Score, ClassificationMetrics from omegaconf import DictConfig from torchmetrics import MetricCollection, Metric +from code2seq.data.path_context import BatchedLabeledPathContext from code2seq.data.vocabulary import CommentVocabulary from code2seq.model import Code2Seq from code2seq.model.modules.comment_decoder import CommentDecoder @@ -26,6 +28,8 @@ def __init__( self._vocabulary = vocabulary tokenizer = vocabulary.tokenizer + + print(tokenizer.convert_ids_to_tokens([225])) self._pad_idx = tokenizer.pad_token_id self._eos_idx = tokenizer.eos_token_id self._sos_idx = tokenizer.bos_token_id @@ -49,3 +53,25 @@ def __init__( ) self._loss = SequenceCrossEntropyLoss(self._pad_idx, reduction="seq-mean") + + def _shared_step(self, batch: 
BatchedLabeledPathContext, step: str) -> Dict: + target_sequence = batch.labels if step != "test" else None + # [seq length; batch size; vocab size] + logits, _ = self.logits_from_batch(batch, target_sequence) + # if step == "test": + logits = logits[1:] + # else: + # logits = logits[:-1] + batch.labels = batch.labels[1:] + result = {f"{step}/loss": self._loss(logits, batch.labels)} + + with torch.no_grad(): + prediction = logits.argmax(-1) + metric: ClassificationMetrics = self._metrics[f"{step}_f1"](prediction, batch.labels) + result.update( + {f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall} + ) + if step != "train": + result[f"{step}/chrf"] = self._metrics[f"{step}_chrf"](prediction, batch.labels) + + return result diff --git a/code2seq/model/modules/comment_decoder.py b/code2seq/model/modules/comment_decoder.py index 7fff0f5..29a49b6 100644 --- a/code2seq/model/modules/comment_decoder.py +++ b/code2seq/model/modules/comment_decoder.py @@ -93,22 +93,23 @@ def forward( tgt_mask = (Transformer.generate_square_subsequent_mask(output_size)).to(device) - if self.training: + if target_sequence is not None: target_sequence = target_sequence.permute(1, 0) output = self.decode(target_sequence, batched_encoder_output, tgt_mask, attention_mask) else: - output = torch.zeros((batch_size, output_size, self._vocab_size)).to(device) - output[:, 0, self._sos_token] = 1 - target_sequence = torch.zeros((batch_size, output_size)).to(device) - target_sequence[:, 1:] = self._pad_token - target_sequence[:, 0] = self._sos_token + with torch.no_grad(): + output = torch.zeros((batch_size, output_size, self._vocab_size)).to(device) + output[:, 0, self._sos_token] = 1 - for i in range(1, output_size): - logits = self.decode(target_sequence, batched_encoder_output, tgt_mask, attention_mask) + target_sequence = torch.zeros((batch_size, output_size)).to(device) + target_sequence[:, 1:] = self._pad_token + target_sequence[:, 0] = self._sos_token + + for i in range(1, output_size): + logits = self.decode(target_sequence, batched_encoder_output, tgt_mask, attention_mask) - with torch.no_grad(): prediction = logits.argmax(-1) target_sequence[:, i] = prediction[:, i] output[:, i, :] = logits[:, i, :] From 6aff1eda893652d9b3031e9af0a5be8bd92775b2 Mon Sep 17 00:00:00 2001 From: malodetz Date: Fri, 12 Aug 2022 15:00:06 +0300 Subject: [PATCH 24/33] Logits cut --- code2seq/model/comment_code2seq.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py index dd7b9be..8456ce6 100644 --- a/code2seq/model/comment_code2seq.py +++ b/code2seq/model/comment_code2seq.py @@ -58,10 +58,10 @@ def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict: target_sequence = batch.labels if step != "test" else None # [seq length; batch size; vocab size] logits, _ = self.logits_from_batch(batch, target_sequence) - # if step == "test": - logits = logits[1:] - # else: - # logits = logits[:-1] + if step == "test": + logits = logits[1:] + else: + logits = logits[:-1] batch.labels = batch.labels[1:] result = {f"{step}/loss": self._loss(logits, batch.labels)} From c23e20a3c042ae1c0cc8aa7a0041402c1e2838a9 Mon Sep 17 00:00:00 2001 From: malodetz Date: Fri, 12 Aug 2022 16:00:29 +0300 Subject: [PATCH 25/33] Fixed greedy generation --- code2seq/model/comment_code2seq.py | 5 +---- code2seq/model/modules/comment_decoder.py | 14 ++++++-------- 2 files changed, 7 insertions(+), 12 deletions(-) diff 
--git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py index 8456ce6..6d0ef08 100644 --- a/code2seq/model/comment_code2seq.py +++ b/code2seq/model/comment_code2seq.py @@ -58,10 +58,7 @@ def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict: target_sequence = batch.labels if step != "test" else None # [seq length; batch size; vocab size] logits, _ = self.logits_from_batch(batch, target_sequence) - if step == "test": - logits = logits[1:] - else: - logits = logits[:-1] + logits = logits[:-1] batch.labels = batch.labels[1:] result = {f"{step}/loss": self._loss(logits, batch.labels)} diff --git a/code2seq/model/modules/comment_decoder.py b/code2seq/model/modules/comment_decoder.py index 29a49b6..5cd53bd 100644 --- a/code2seq/model/modules/comment_decoder.py +++ b/code2seq/model/modules/comment_decoder.py @@ -91,27 +91,25 @@ def forward( # TODO fill attentions with smth good attentions = batched_encoder_output.new_zeros((output_size, batch_size, attention_mask.shape[1])) - tgt_mask = (Transformer.generate_square_subsequent_mask(output_size)).to(device) - if target_sequence is not None: target_sequence = target_sequence.permute(1, 0) + tgt_mask = (Transformer.generate_square_subsequent_mask(output_size)).to(device) + output = self.decode(target_sequence, batched_encoder_output, tgt_mask, attention_mask) else: - with torch.no_grad(): output = torch.zeros((batch_size, output_size, self._vocab_size)).to(device) - output[:, 0, self._sos_token] = 1 - target_sequence = torch.zeros((batch_size, output_size)).to(device) - target_sequence[:, 1:] = self._pad_token + target_sequence = torch.zeros((batch_size, 1)).to(device) target_sequence[:, 0] = self._sos_token - for i in range(1, output_size): + for i in range(output_size): + tgt_mask = (Transformer.generate_square_subsequent_mask(i + 1)).to(device) logits = self.decode(target_sequence, batched_encoder_output, tgt_mask, attention_mask) prediction = logits.argmax(-1) - target_sequence[:, i] = prediction[:, i] + target_sequence = torch.cat((target_sequence, prediction[:, i].unsqueeze(1)), dim=1) output[:, i, :] = logits[:, i, :] return output.permute(1, 0, 2), attentions From 37f3db27d346c92fae4f4501f0e4795c7d85280d Mon Sep 17 00:00:00 2001 From: malodetz Date: Fri, 12 Aug 2022 16:34:48 +0300 Subject: [PATCH 26/33] Some train changes to fix --- code2seq/utils/train.py | 5 ++--- config/comment-code2seq-java-large.yaml | 8 ++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/code2seq/utils/train.py b/code2seq/utils/train.py index d9940f2..a70d73d 100644 --- a/code2seq/utils/train.py +++ b/code2seq/utils/train.py @@ -1,4 +1,3 @@ -import torch from commode_utils.callbacks import ModelCheckpointWithUploadCallback, PrintEpochResultCallback from omegaconf import DictConfig, OmegaConf from pytorch_lightning import seed_everything, Trainer, LightningModule, LightningDataModule @@ -24,7 +23,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict dirpath=wandb_logger.experiment.dir, filename="{epoch:02d}-val_loss={val/loss:.4f}", monitor="val/loss", - every_n_epochs=params.save_every_epoch, + every_n_train_steps=100, save_top_k=-1, auto_insert_metric_name=False, ) @@ -40,7 +39,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict max_epochs=params.n_epochs, gradient_clip_val=params.clip_norm, deterministic=True, - check_val_every_n_epoch=params.val_every_epoch, + val_check_interval=100, log_every_n_steps=params.log_every_n_steps, 
logger=wandb_logger, gpus=params.gpu, diff --git a/config/comment-code2seq-java-large.yaml b/config/comment-code2seq-java-large.yaml index 39aa643..46d1fc1 100644 --- a/config/comment-code2seq-java-large.yaml +++ b/config/comment-code2seq-java-large.yaml @@ -20,18 +20,18 @@ data: max_tokenizer_vocab: 20000 # Each token appears at least 10 times (99.2% coverage) - labels_count: 1 + labels_count: 10 max_label_parts: 256 # Each token appears at least 1000 times (99.5% coverage) - tokens_count: 1 + tokens_count: 1000 max_token_parts: 5 path_length: 9 max_context: 200 random_context: true - batch_size: 16 - test_batch_size: 16 + batch_size: 128 + test_batch_size: 128 model: # Encoder From a19e75216c892cd4c3bd128651b4c7e0d236e549 Mon Sep 17 00:00:00 2001 From: malodetz Date: Thu, 25 Aug 2022 11:44:47 +0300 Subject: [PATCH 27/33] Fix train --- code2seq/utils/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/code2seq/utils/train.py b/code2seq/utils/train.py index a70d73d..e422d78 100644 --- a/code2seq/utils/train.py +++ b/code2seq/utils/train.py @@ -23,7 +23,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict dirpath=wandb_logger.experiment.dir, filename="{epoch:02d}-val_loss={val/loss:.4f}", monitor="val/loss", - every_n_train_steps=100, + every_n_epochs=params.save_every_epoch, save_top_k=-1, auto_insert_metric_name=False, ) @@ -40,6 +40,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict gradient_clip_val=params.clip_norm, deterministic=True, val_check_interval=100, + check_val_every_n_epoch=params.val_every_epoch, log_every_n_steps=params.log_every_n_steps, logger=wandb_logger, gpus=params.gpu, From e43b664110140131bc63255bb2c91e5fc633eb5b Mon Sep 17 00:00:00 2001 From: malodetz Date: Tue, 30 Aug 2022 13:18:56 +0000 Subject: [PATCH 28/33] =?UTF-8?q?=D0=A1onfig=20for=20transformer=20decoder?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/comment-code2seq-java-large.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/config/comment-code2seq-java-large.yaml b/config/comment-code2seq-java-large.yaml index 46d1fc1..e950a3a 100644 --- a/config/comment-code2seq-java-large.yaml +++ b/config/comment-code2seq-java-large.yaml @@ -30,8 +30,8 @@ data: max_context: 200 random_context: true - batch_size: 128 - test_batch_size: 128 + batch_size: 64 + test_batch_size: 64 model: # Encoder @@ -42,10 +42,10 @@ model: rnn_num_layers: 1 # Decoder - decoder_size: 320 + decoder_size: 512 decoder_num_layers: 4 - decoder_dim_feedforward: 1024 - decoder_num_heads: 4 + decoder_dim_feedforward: 2048 + decoder_num_heads: 8 decoder_dropout: 0.1 optimizer: @@ -63,4 +63,4 @@ train: teacher_forcing: 1.0 val_every_epoch: 1 save_every_epoch: 1 - log_every_n_steps: 10 \ No newline at end of file + log_every_n_steps: 10 From c51f14ce5b4f06c90a9d437781778ecea6e07e6e Mon Sep 17 00:00:00 2001 From: malodetz Date: Tue, 30 Aug 2022 13:50:39 +0000 Subject: [PATCH 29/33] Multiple decoders --- code2seq/data/vocabulary.py | 5 +-- code2seq/model/comment_code2seq.py | 43 +++++++++++++------ code2seq/model/modules/path_encoder.py | 7 +-- ...oder.py => transformer_comment_decoder.py} | 12 ++---- code2seq/utils/optimization.py | 2 +- code2seq/utils/train.py | 1 - ...ment-code2seq-transformer-java-large.yaml} | 1 + 7 files changed, 37 insertions(+), 34 deletions(-) rename code2seq/model/modules/{comment_decoder.py => transformer_comment_decoder.py} (94%) rename 
config/{comment-code2seq-java-large.yaml => comment-code2seq-transformer-java-large.yaml} (97%) diff --git a/code2seq/data/vocabulary.py b/code2seq/data/vocabulary.py index 9e27a52..b0f114e 100644 --- a/code2seq/data/vocabulary.py +++ b/code2seq/data/vocabulary.py @@ -75,10 +75,7 @@ def process_raw_sample(raw_sample: str, counters: Dict[str, CounterType[str]]): class CommentVocabulary(Vocabulary): def __init__( - self, - vocabulary_file: str, - labels_count: Optional[int] = None, - tokens_count: Optional[int] = None, + self, vocabulary_file: str, labels_count: Optional[int] = None, tokens_count: Optional[int] = None, ): super().__init__(vocabulary_file, labels_count, tokens_count) with open(vocabulary_file, "rb") as f_in: diff --git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py index 6d0ef08..f952bcb 100644 --- a/code2seq/model/comment_code2seq.py +++ b/code2seq/model/comment_code2seq.py @@ -3,13 +3,14 @@ import torch from commode_utils.losses import SequenceCrossEntropyLoss from commode_utils.metrics import SequentialF1Score, ClassificationMetrics +from commode_utils.modules import LSTMDecoderStep, Decoder from omegaconf import DictConfig from torchmetrics import MetricCollection, Metric from code2seq.data.path_context import BatchedLabeledPathContext from code2seq.data.vocabulary import CommentVocabulary from code2seq.model import Code2Seq -from code2seq.model.modules.comment_decoder import CommentDecoder +from code2seq.model.modules.transformer_comment_decoder import TransformerCommentDecoder from code2seq.model.modules.metrics import CommentChrF @@ -29,7 +30,6 @@ def __init__( tokenizer = vocabulary.tokenizer - print(tokenizer.convert_ids_to_tokens([225])) self._pad_idx = tokenizer.pad_token_id self._eos_idx = tokenizer.eos_token_id self._sos_idx = tokenizer.bos_token_id @@ -44,16 +44,25 @@ def __init__( self._metrics = MetricCollection(metrics) self._encoder = self._get_encoder(model_config) - self._decoder = CommentDecoder( - model_config, - vocab_size=tokenizer.vocab_size, - pad_token=self._pad_idx, - sos_token=self._sos_idx, - teacher_forcing=teacher_forcing, - ) + self._decoder = self.get_decoder(model_config, tokenizer.vocab_size, teacher_forcing) self._loss = SequenceCrossEntropyLoss(self._pad_idx, reduction="seq-mean") + def get_decoder(self, model_config: DictConfig, vocab_size: int, teacher_forcing: float) -> torch.nn.Module: + if model_config.decoder_type == "LSTM": + decoder_step = LSTMDecoderStep(model_config, vocab_size, self._pad_idx) + return Decoder(decoder_step, vocab_size, self._sos_idx, teacher_forcing) + elif model_config.decoder_type == "Transformer": + return TransformerCommentDecoder( + model_config, + vocab_size=vocab_size, + pad_token=self._pad_idx, + sos_token=self._sos_idx, + teacher_forcing=teacher_forcing + ) + else: + raise ValueError + def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict: target_sequence = batch.labels if step != "test" else None # [seq length; batch size; vocab size] @@ -63,12 +72,18 @@ def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict: result = {f"{step}/loss": self._loss(logits, batch.labels)} with torch.no_grad(): - prediction = logits.argmax(-1) - metric: ClassificationMetrics = self._metrics[f"{step}_f1"](prediction, batch.labels) - result.update( - {f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall} - ) if step != "train": + prediction = logits.argmax(-1) + metric: ClassificationMetrics = 
self._metrics[f"{step}_f1"](prediction, batch.labels) + result.update( + { + f"{step}/f1": metric.f1_score, + f"{step}/precision": metric.precision, + f"{step}/recall": metric.recall, + } + ) result[f"{step}/chrf"] = self._metrics[f"{step}_chrf"](prediction, batch.labels) + else: + result.update({f"{step}/f1": 0, f"{step}/precision": 0, f"{step}/recall": 0}) return result diff --git a/code2seq/model/modules/path_encoder.py b/code2seq/model/modules/path_encoder.py index 678236c..da1832c 100644 --- a/code2seq/model/modules/path_encoder.py +++ b/code2seq/model/modules/path_encoder.py @@ -7,12 +7,7 @@ class PathEncoder(nn.Module): def __init__( - self, - config: DictConfig, - n_tokens: int, - token_pad_id: int, - n_nodes: int, - node_pad_id: int, + self, config: DictConfig, n_tokens: int, token_pad_id: int, n_nodes: int, node_pad_id: int, ): super().__init__() self.node_pad_id = node_pad_id diff --git a/code2seq/model/modules/comment_decoder.py b/code2seq/model/modules/transformer_comment_decoder.py similarity index 94% rename from code2seq/model/modules/comment_decoder.py rename to code2seq/model/modules/transformer_comment_decoder.py index 5cd53bd..e62652b 100644 --- a/code2seq/model/modules/comment_decoder.py +++ b/code2seq/model/modules/transformer_comment_decoder.py @@ -4,8 +4,8 @@ from commode_utils.training import cut_into_segments from omegaconf import DictConfig from torch import nn, Tensor, LongTensor -from torch.nn import TransformerDecoder, Embedding, Linear -from torch.nn.modules.transformer import TransformerDecoderLayer, Transformer +from torch.nn import Embedding, Linear +from torch.nn.modules.transformer import TransformerDecoderLayer, Transformer, TransformerDecoder from typing import Tuple @@ -39,7 +39,7 @@ def forward(self, tokens: Tensor): return self.embedding(tokens.long()) * math.sqrt(self.emb_size) -class CommentDecoder(nn.Module): +class TransformerCommentDecoder(nn.Module): def __init__( self, config: DictConfig, vocab_size: int, pad_token: int, sos_token: int, teacher_forcing: float = 0.0 ): @@ -78,11 +78,7 @@ def decode( return self._linear(decoded) def forward( - self, - encoder_output: Tensor, - segment_sizes: LongTensor, - output_size: int, - target_sequence: Tensor = None, + self, encoder_output: Tensor, segment_sizes: LongTensor, output_size: int, target_sequence: Tensor = None, ) -> Tuple[Tensor, Tensor]: device = encoder_output.get_device() batch_size = segment_sizes.shape[0] diff --git a/code2seq/utils/optimization.py b/code2seq/utils/optimization.py index 7b3183a..7d3e501 100644 --- a/code2seq/utils/optimization.py +++ b/code2seq/utils/optimization.py @@ -29,5 +29,5 @@ def configure_optimizers_alon( optimizer = Adam(parameters, optim_config.lr, weight_decay=optim_config.weight_decay) else: raise ValueError(f"Unknown optimizer name: {optim_config.optimizer}, try one of: Adam, Momentum") - scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: optim_config.decay_gamma**epoch) + scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: optim_config.decay_gamma ** epoch) return [optimizer], [scheduler] diff --git a/code2seq/utils/train.py b/code2seq/utils/train.py index e422d78..5d552b8 100644 --- a/code2seq/utils/train.py +++ b/code2seq/utils/train.py @@ -39,7 +39,6 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict max_epochs=params.n_epochs, gradient_clip_val=params.clip_norm, deterministic=True, - val_check_interval=100, check_val_every_n_epoch=params.val_every_epoch, log_every_n_steps=params.log_every_n_steps, 
logger=wandb_logger, diff --git a/config/comment-code2seq-java-large.yaml b/config/comment-code2seq-transformer-java-large.yaml similarity index 97% rename from config/comment-code2seq-java-large.yaml rename to config/comment-code2seq-transformer-java-large.yaml index e950a3a..a54552a 100644 --- a/config/comment-code2seq-java-large.yaml +++ b/config/comment-code2seq-transformer-java-large.yaml @@ -42,6 +42,7 @@ model: rnn_num_layers: 1 # Decoder + decoder_type: "Transformer" decoder_size: 512 decoder_num_layers: 4 decoder_dim_feedforward: 2048 From d09dc30324a0d2aa224f5b2f7c172c36556207ec Mon Sep 17 00:00:00 2001 From: malodetz Date: Tue, 30 Aug 2022 14:26:15 +0000 Subject: [PATCH 30/33] Another config --- config/comment-code2seq-lstm-java-large.yaml | 65 ++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 config/comment-code2seq-lstm-java-large.yaml diff --git a/config/comment-code2seq-lstm-java-large.yaml b/config/comment-code2seq-lstm-java-large.yaml new file mode 100644 index 0000000..9173b22 --- /dev/null +++ b/config/comment-code2seq-lstm-java-large.yaml @@ -0,0 +1,65 @@ +data_folder: ./dataset + +checkpoint: null + +seed: 7 +# Training in notebooks (e.g. Google Colab) may crash with too small value +progress_bar_refresh_rate: 1 +print_config: true + +wandb: + project: comment-code2seq + group: null + offline: true + +data: + num_workers: 4 + + base_tokenizer: "microsoft/codebert-base" + train_new_tokenizer: true + max_tokenizer_vocab: 20000 + + # Each token appears at least 10 times (99.2% coverage) + labels_count: 10 + max_label_parts: 256 + # Each token appears at least 1000 times (99.5% coverage) + tokens_count: 1000 + max_token_parts: 5 + path_length: 9 + + max_context: 200 + random_context: true + + batch_size: 128 + test_batch_size: 128 + +model: + # Encoder + embedding_size: 128 + encoder_dropout: 0.25 + encoder_rnn_size: 128 + use_bi_rnn: true + rnn_num_layers: 1 + + # Decoder + decoder_type: "LSTM" + decoder_size: 320 + decoder_num_layers: 1 + rnn_dropout: 0.5 + +optimizer: + optimizer: "Momentum" + nesterov: true + lr: 0.01 + weight_decay: 0 + decay_gamma: 0.95 + +train: + gpu: 1 + n_epochs: 100 + patience: 10 + clip_norm: 5 + teacher_forcing: 1.0 + val_every_epoch: 1 + save_every_epoch: 1 + log_every_n_steps: 10 \ No newline at end of file From 9bfa25cc80761fee120246fece911a73745aece8 Mon Sep 17 00:00:00 2001 From: malodetz Date: Tue, 30 Aug 2022 16:38:17 +0000 Subject: [PATCH 31/33] Add early stop for greedy generation --- code2seq/model/comment_code2seq.py | 3 ++- .../modules/transformer_comment_decoder.py | 18 +++++++++++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py index f952bcb..01db29c 100644 --- a/code2seq/model/comment_code2seq.py +++ b/code2seq/model/comment_code2seq.py @@ -58,7 +58,8 @@ def get_decoder(self, model_config: DictConfig, vocab_size: int, teacher_forcing vocab_size=vocab_size, pad_token=self._pad_idx, sos_token=self._sos_idx, - teacher_forcing=teacher_forcing + eos_token=self._eos_idx, + teacher_forcing=teacher_forcing, ) else: raise ValueError diff --git a/code2seq/model/modules/transformer_comment_decoder.py b/code2seq/model/modules/transformer_comment_decoder.py index e62652b..1846591 100644 --- a/code2seq/model/modules/transformer_comment_decoder.py +++ b/code2seq/model/modules/transformer_comment_decoder.py @@ -41,12 +41,19 @@ def forward(self, tokens: Tensor): class TransformerCommentDecoder(nn.Module): def __init__( - self, 
config: DictConfig, vocab_size: int, pad_token: int, sos_token: int, teacher_forcing: float = 0.0 + self, + config: DictConfig, + vocab_size: int, + pad_token: int, + sos_token: int, + eos_token: int, + teacher_forcing: float = 0.0, ): super().__init__() self._vocab_size = vocab_size self._pad_token = pad_token self._sos_token = sos_token + self._eos_token = eos_token self._teacher_forcing = teacher_forcing self._embedding = TokenEmbedding(vocab_size, config.decoder_size) @@ -99,13 +106,18 @@ def forward( target_sequence = torch.zeros((batch_size, 1)).to(device) target_sequence[:, 0] = self._sos_token + is_ended = torch.zeros(batch_size, dtype=torch.bool) for i in range(output_size): tgt_mask = (Transformer.generate_square_subsequent_mask(i + 1)).to(device) logits = self.decode(target_sequence, batched_encoder_output, tgt_mask, attention_mask) - prediction = logits.argmax(-1) - target_sequence = torch.cat((target_sequence, prediction[:, i].unsqueeze(1)), dim=1) + prediction = logits.argmax(-1)[:, i] + target_sequence = torch.cat((target_sequence, prediction.unsqueeze(1)), dim=1) output[:, i, :] = logits[:, i, :] + is_ended = torch.logical_or(is_ended, (prediction == self._eos_token)) + if torch.count_nonzero(is_ended)[0] == batch_size: + break + return output.permute(1, 0, 2), attentions From b0b7b81716883f8ab4f2315844e04a10425ffc2b Mon Sep 17 00:00:00 2001 From: malodetz Date: Tue, 30 Aug 2022 21:02:45 +0300 Subject: [PATCH 32/33] Early generation stop --- code2seq/data/vocabulary.py | 5 ++++- code2seq/model/modules/path_encoder.py | 7 ++++++- code2seq/model/modules/transformer_comment_decoder.py | 10 +++++++--- code2seq/utils/optimization.py | 2 +- config/comment-code2seq-transformer-java-large.yaml | 2 +- 5 files changed, 19 insertions(+), 7 deletions(-) diff --git a/code2seq/data/vocabulary.py b/code2seq/data/vocabulary.py index b0f114e..9e27a52 100644 --- a/code2seq/data/vocabulary.py +++ b/code2seq/data/vocabulary.py @@ -75,7 +75,10 @@ def process_raw_sample(raw_sample: str, counters: Dict[str, CounterType[str]]): class CommentVocabulary(Vocabulary): def __init__( - self, vocabulary_file: str, labels_count: Optional[int] = None, tokens_count: Optional[int] = None, + self, + vocabulary_file: str, + labels_count: Optional[int] = None, + tokens_count: Optional[int] = None, ): super().__init__(vocabulary_file, labels_count, tokens_count) with open(vocabulary_file, "rb") as f_in: diff --git a/code2seq/model/modules/path_encoder.py b/code2seq/model/modules/path_encoder.py index da1832c..678236c 100644 --- a/code2seq/model/modules/path_encoder.py +++ b/code2seq/model/modules/path_encoder.py @@ -7,7 +7,12 @@ class PathEncoder(nn.Module): def __init__( - self, config: DictConfig, n_tokens: int, token_pad_id: int, n_nodes: int, node_pad_id: int, + self, + config: DictConfig, + n_tokens: int, + token_pad_id: int, + n_nodes: int, + node_pad_id: int, ): super().__init__() self.node_pad_id = node_pad_id diff --git a/code2seq/model/modules/transformer_comment_decoder.py b/code2seq/model/modules/transformer_comment_decoder.py index 1846591..608ff4a 100644 --- a/code2seq/model/modules/transformer_comment_decoder.py +++ b/code2seq/model/modules/transformer_comment_decoder.py @@ -85,7 +85,11 @@ def decode( return self._linear(decoded) def forward( - self, encoder_output: Tensor, segment_sizes: LongTensor, output_size: int, target_sequence: Tensor = None, + self, + encoder_output: Tensor, + segment_sizes: LongTensor, + output_size: int, + target_sequence: Tensor = None, ) -> Tuple[Tensor, Tensor]: 
device = encoder_output.get_device() batch_size = segment_sizes.shape[0] @@ -106,7 +110,7 @@ def forward( target_sequence = torch.zeros((batch_size, 1)).to(device) target_sequence[:, 0] = self._sos_token - is_ended = torch.zeros(batch_size, dtype=torch.bool) + is_ended = torch.zeros(batch_size, dtype=torch.bool).to(device) for i in range(output_size): tgt_mask = (Transformer.generate_square_subsequent_mask(i + 1)).to(device) @@ -117,7 +121,7 @@ def forward( output[:, i, :] = logits[:, i, :] is_ended = torch.logical_or(is_ended, (prediction == self._eos_token)) - if torch.count_nonzero(is_ended)[0] == batch_size: + if torch.count_nonzero(is_ended).item() == batch_size: break return output.permute(1, 0, 2), attentions diff --git a/code2seq/utils/optimization.py b/code2seq/utils/optimization.py index 7d3e501..7b3183a 100644 --- a/code2seq/utils/optimization.py +++ b/code2seq/utils/optimization.py @@ -29,5 +29,5 @@ def configure_optimizers_alon( optimizer = Adam(parameters, optim_config.lr, weight_decay=optim_config.weight_decay) else: raise ValueError(f"Unknown optimizer name: {optim_config.optimizer}, try one of: Adam, Momentum") - scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: optim_config.decay_gamma ** epoch) + scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: optim_config.decay_gamma**epoch) return [optimizer], [scheduler] diff --git a/config/comment-code2seq-transformer-java-large.yaml b/config/comment-code2seq-transformer-java-large.yaml index a54552a..548f0ea 100644 --- a/config/comment-code2seq-transformer-java-large.yaml +++ b/config/comment-code2seq-transformer-java-large.yaml @@ -10,7 +10,7 @@ print_config: true wandb: project: comment-code2seq group: null - offline: true + offline: false data: num_workers: 4 From 54cb71121bb620f615c2ce7fc37dd3a76f2249c4 Mon Sep 17 00:00:00 2001 From: malodetz Date: Sat, 17 Sep 2022 18:33:20 +0300 Subject: [PATCH 33/33] Added predictions --- code2seq/comment_code2seq_wrapper.py | 43 ++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/code2seq/comment_code2seq_wrapper.py b/code2seq/comment_code2seq_wrapper.py index b9b0a09..1ad1572 100644 --- a/code2seq/comment_code2seq_wrapper.py +++ b/code2seq/comment_code2seq_wrapper.py @@ -15,11 +15,14 @@ def configure_arg_parser() -> ArgumentParser: arg_parser = ArgumentParser() - arg_parser.add_argument("mode", help="Mode to run script", choices=["train", "test"]) + arg_parser.add_argument("mode", help="Mode to run script", choices=["train", "test", "predict"]) arg_parser.add_argument("-c", "--config", help="Path to YAML configuration file", type=str) arg_parser.add_argument( "-p", "--pretrained", help="Path to pretrained model", type=str, required=False, default=None ) + arg_parser.add_argument( + "-o", "--output", help="Output file for predictions", type=str, required=False, default=None + ) return arg_parser @@ -50,6 +53,39 @@ def test_code2seq(model_path: str, config: DictConfig): test(code2seq, data_module, config.seed) +def save_predictions(model_path: str, config: DictConfig, output_path: str): + filter_warnings() + + data_module = CommentPathContextDataModule(config.data_folder, config.data) + tokenizer = data_module.vocabulary.tokenizer + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + code2seq = CommentCode2Seq.load_from_checkpoint(model_path) + code2seq.to(device) + code2seq.eval() + + with open(output_path, "w") as f: + for batch in data_module.test_dataloader(): + data_module.transfer_batch_to_device(batch, device, 0) 
+ logits, _ = code2seq.logits_from_batch(batch, None) + + predictions = logits[:-1].argmax(-1) + targets = batch.labels[1:] + + batch_size = targets.shape[1] + for batch_idx in range(batch_size): + target_seq = [token.item() for token in targets[:, batch_idx]] + predicted_seq = [token.item() for token in predictions[:, batch_idx]] + + target_str = tokenizer.decode(target_seq, skip_special_tokens=True) + predicted_str = tokenizer.decode(predicted_seq, skip_special_tokens=True) + + if target_str == "": + continue + + print(target_str.replace(" ", "|"), predicted_str.replace(" ", "|"), file=f) + + if __name__ == "__main__": __arg_parser = configure_arg_parser() __args = __arg_parser.parse_args() @@ -60,4 +96,7 @@ def test_code2seq(model_path: str, config: DictConfig): train_code2seq(__config) else: assert __args.pretrained is not None - test_code2seq(__args.pretrained, __config) + if __args.mode == "test": + test_code2seq(__args.pretrained, __config) + else: + save_predictions(__args.pretrained, __config, __args.output)
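
A minimal sketch of the label encoding used from PATCH 19 onwards (not part of the patches themselves): the series replaces the original sub-token vocabulary with a Hugging Face BPE tokenizer, and `CommentPathContextDataset.tokenize_label` now just calls the tokenizer with padding up to `max_label_parts`. The sketch assumes the `transformers` package and the `microsoft/codebert-base` checkpoint are available, that the c2s label separator is `|`, and uses an illustrative label and `max_label_parts` value that do not come from the patches.

    from transformers import RobertaTokenizerFast

    # Sketch only: mirrors the padding behaviour of tokenize_label after PATCH 19.
    tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base")

    raw_label = "returns|the|sum|of|two|numbers"  # hypothetical c2s-style label
    max_label_parts = 16                          # hypothetical config value

    encoded = tokenizer(
        raw_label.replace("|", " "),   # join sub-tokens with spaces, as the dataset does
        add_special_tokens=True,       # adds <s> ... </s> around the BPE ids
        padding="max_length",          # pad up to max_label_parts with <pad> ids
        max_length=max_label_parts,
    )
    print(encoded["input_ids"])                                          # bos, BPE ids, eos, pads
    print(tokenizer.decode(encoded["input_ids"], skip_special_tokens=True))  # recovered comment text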
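
Usage note for the `predict` mode added in PATCH 33: it loads a trained `CommentCode2Seq` checkpoint, decodes the test holdout greedily, and writes one line per test sample (samples with empty targets are skipped) containing the target and the prediction separated by a space, with inner spaces replaced by `|`. A typical invocation, assuming the package is importable from the working directory (e.g. installed with `pip install -e .`) and using placeholder paths for the checkpoint and output file, would be roughly:

    python code2seq/comment_code2seq_wrapper.py predict \
        -c config/comment-code2seq-transformer-java-large.yaml \
        -p path/to/checkpoint.ckpt \
        -o predictions.txt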