diff --git a/code2seq/comment_code2seq_wrapper.py b/code2seq/comment_code2seq_wrapper.py
new file mode 100644
index 0000000..1ad1572
--- /dev/null
+++ b/code2seq/comment_code2seq_wrapper.py
@@ -0,0 +1,102 @@
+from argparse import ArgumentParser
+from typing import cast
+
+import torch
+from commode_utils.common import print_config
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning import seed_everything
+
+from code2seq.data.comment_path_context_data_module import CommentPathContextDataModule
+from code2seq.model.comment_code2seq import CommentCode2Seq
+from code2seq.utils.common import filter_warnings
+from code2seq.utils.test import test
+from code2seq.utils.train import train
+
+
+def configure_arg_parser() -> ArgumentParser:
+    arg_parser = ArgumentParser()
+    arg_parser.add_argument("mode", help="Mode to run script", choices=["train", "test", "predict"])
+    arg_parser.add_argument("-c", "--config", help="Path to YAML configuration file", type=str)
+    arg_parser.add_argument(
+        "-p", "--pretrained", help="Path to pretrained model", type=str, required=False, default=None
+    )
+    arg_parser.add_argument(
+        "-o", "--output", help="Output file for predictions", type=str, required=False, default=None
+    )
+    return arg_parser
+
+
+def train_code2seq(config: DictConfig):
+    filter_warnings()
+
+    if config.print_config:
+        print_config(config, fields=["model", "data", "train", "optimizer"])
+
+    # Load data module
+    data_module = CommentPathContextDataModule(config.data_folder, config.data)
+
+    # Load model
+    code2seq = CommentCode2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing)
+
+    train(code2seq, data_module, config)
+
+
+def test_code2seq(model_path: str, config: DictConfig):
+    filter_warnings()
+
+    # Load data module
+    data_module = CommentPathContextDataModule(config.data_folder, config.data)
+
+    # Load model
+    code2seq = CommentCode2Seq.load_from_checkpoint(model_path, map_location=torch.device("cpu"))
+
+    test(code2seq, data_module, config.seed)
+
+
+def save_predictions(model_path: str, config: DictConfig, output_path: str):
+    filter_warnings()
+
+    data_module = CommentPathContextDataModule(config.data_folder, config.data)
+    tokenizer = data_module.vocabulary.tokenizer
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    code2seq = CommentCode2Seq.load_from_checkpoint(model_path)
+    code2seq.to(device)
+    code2seq.eval()
+
+    with open(output_path, "w") as f:
+        for batch in data_module.test_dataloader():
+            data_module.transfer_batch_to_device(batch, device, 0)
+            logits, _ = code2seq.logits_from_batch(batch, None)
+
+            predictions = logits[:-1].argmax(-1)
+            targets = batch.labels[1:]
+
+            batch_size = targets.shape[1]
+            for batch_idx in range(batch_size):
+                target_seq = [token.item() for token in targets[:, batch_idx]]
+                predicted_seq = [token.item() for token in predictions[:, batch_idx]]
+
+                target_str = tokenizer.decode(target_seq, skip_special_tokens=True)
+                predicted_str = tokenizer.decode(predicted_seq, skip_special_tokens=True)
+
+                if target_str == "":
+                    continue
+
+                print(target_str.replace(" ", "|"), predicted_str.replace(" ", "|"), file=f)
+
+
+if __name__ == "__main__":
+    __arg_parser = configure_arg_parser()
+    __args = __arg_parser.parse_args()
+
+    __config = cast(DictConfig, OmegaConf.load(__args.config))
+    seed_everything(__config.seed)
+    if __args.mode == "train":
+        train_code2seq(__config)
+    else:
+        assert __args.pretrained is not None
+        if __args.mode == "test":
+            test_code2seq(__args.pretrained, __config)
+        else:
+            save_predictions(__args.pretrained, __config, __args.output)
diff --git a/code2seq/data/comment_path_context_data_module.py b/code2seq/data/comment_path_context_data_module.py
new file mode 100644
index 0000000..6de9f84
--- /dev/null
+++ b/code2seq/data/comment_path_context_data_module.py
@@ -0,0 +1,68 @@
+import pickle
+from collections import Counter
+from os.path import join, exists, dirname
+from typing import Dict, Counter as TCounter, Type
+
+from commode_utils.vocabulary import BaseVocabulary
+from tqdm.auto import tqdm
+
+from commode_utils.filesystem import count_lines_in_file
+from omegaconf import DictConfig
+from transformers import RobertaTokenizerFast
+
+from code2seq.data.comment_path_context_dataset import CommentPathContextDataset
+from code2seq.data.path_context_data_module import PathContextDataModule
+from code2seq.data.vocabulary import CommentVocabulary
+
+
+def _build_from_scratch(config: DictConfig, train_data: str, vocabulary_cls: Type[BaseVocabulary]):
+    total_samples = count_lines_in_file(train_data)
+    counters: Dict[str, TCounter[str]] = {
+        key: Counter() for key in [vocabulary_cls.LABEL, vocabulary_cls.TOKEN, vocabulary_cls.NODE]
+    }
+    with open(train_data, "r") as f_in:
+        for raw_sample in tqdm(f_in, total=total_samples):
+            vocabulary_cls.process_raw_sample(raw_sample, counters)
+
+    training_corpus = []
+    for string, amount in counters[vocabulary_cls.LABEL].items():
+        training_corpus.extend([string] * amount)
+    old_tokenizer = RobertaTokenizerFast.from_pretrained(config.base_tokenizer)
+    if config.train_new_tokenizer:
+        tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, config.max_tokenizer_vocab)
+    else:
+        tokenizer = old_tokenizer
+
+    for feature, counter in counters.items():
+        print(f"Counted {len(counter)} {feature}, top-5: {counter.most_common(5)}")
+
+    dataset_dir = dirname(train_data)
+    vocabulary_file = join(dataset_dir, vocabulary_cls.vocab_filename)
+    with open(vocabulary_file, "wb") as f_out:
+        pickle.dump(counters, f_out)
+        pickle.dump(tokenizer, f_out)
+
+
+class CommentPathContextDataModule(PathContextDataModule):
+    _vocabulary: CommentVocabulary
+
+    def __init__(self, data_dir: str, config: DictConfig):
+        super().__init__(data_dir, config)
+
+    def _create_dataset(self, holdout_file: str, random_context: bool) -> CommentPathContextDataset:
+        if self._vocabulary is None:
+            raise RuntimeError("Set up the vocabulary before creating data loaders")
+        return CommentPathContextDataset(holdout_file, self._config, self._vocabulary, random_context)
+
+    def setup_vocabulary(self) -> CommentVocabulary:
+        if not exists(join(self._data_dir, CommentVocabulary.vocab_filename)):
+            print("Can't find the vocabulary, collecting it from the train holdout")
+            _build_from_scratch(self._config, join(self._data_dir, f"{self._train}.c2s"), CommentVocabulary)
+        vocabulary_path = join(self._data_dir, CommentVocabulary.vocab_filename)
+        return CommentVocabulary(vocabulary_path, self._config.labels_count, self._config.tokens_count)
+
+    @property
+    def vocabulary(self) -> CommentVocabulary:
+        if self._vocabulary is None:
+            raise RuntimeError("Set up the data module before accessing the vocabulary")
+        return self._vocabulary
diff --git a/code2seq/data/comment_path_context_dataset.py b/code2seq/data/comment_path_context_dataset.py
new file mode 100644
index 0000000..11b57e2
--- /dev/null
+++ b/code2seq/data/comment_path_context_dataset.py
@@ -0,0 +1,22 @@
+from typing import Dict, List, Optional
+
+from code2seq.data.vocabulary import CommentVocabulary
+from omegaconf import DictConfig
+
+from code2seq.data.path_context_dataset import PathContextDataset
+
+
+class CommentPathContextDataset(PathContextDataset):
+    def __init__(self, data_file: str, config: DictConfig, vocabulary: CommentVocabulary, random_context: bool):
+        super().__init__(data_file, config, vocabulary, random_context)
+        self._vocab: CommentVocabulary = vocabulary
+
+    def tokenize_label(self, raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]:
+        tokenizer = self._vocab.tokenizer
+        tokenized_snippet = tokenizer(
+            raw_label.replace(PathContextDataset._separator, " "),
+            add_special_tokens=True,
+            padding="max_length" if max_parts else "do_not_pad",
+            max_length=max_parts,
+        )
+        return tokenized_snippet["input_ids"]
diff --git a/code2seq/data/vocabulary.py b/code2seq/data/vocabulary.py
index f0a575a..9e27a52 100644
--- a/code2seq/data/vocabulary.py
+++ b/code2seq/data/vocabulary.py
@@ -1,3 +1,4 @@
+import pickle
 from argparse import ArgumentParser
 from collections import Counter
 from os.path import dirname, join
@@ -5,6 +6,7 @@ from typing import Dict, Counter as CounterType, Optional, List
 
 from commode_utils.vocabulary import BaseVocabulary, build_from_scratch
+from transformers import PreTrainedTokenizerFast
 
 
 class Vocabulary(BaseVocabulary):
@@ -71,6 +73,19 @@ def process_raw_sample(raw_sample: str, counters: Dict[str, CounterType[str]]):
         TypedVocabulary._process_raw_sample(raw_sample, counters, context_seq)
 
 
+class CommentVocabulary(Vocabulary):
+    def __init__(
+        self,
+        vocabulary_file: str,
+        labels_count: Optional[int] = None,
+        tokens_count: Optional[int] = None,
+    ):
+        super().__init__(vocabulary_file, labels_count, tokens_count)
+        with open(vocabulary_file, "rb") as f_in:
+            pickle.load(f_in)
+            self.tokenizer: PreTrainedTokenizerFast = pickle.load(f_in)
+
+
 def convert_from_vanilla(vocabulary_path: str):
     counters: Dict[str, CounterType[str]] = {}
     with open(vocabulary_path, "rb") as dict_file:
diff --git a/code2seq/model/code2seq.py b/code2seq/model/code2seq.py
index 603760d..aa644a1 100644
--- a/code2seq/model/code2seq.py
+++ b/code2seq/model/code2seq.py
@@ -35,26 +35,26 @@ def __init__(
         if vocabulary.SOS not in vocabulary.label_to_id:
             raise ValueError(f"Can't find SOS token in label to id vocabulary")
 
-        self.__pad_idx = vocabulary.label_to_id[vocabulary.PAD]
+        self._pad_idx = vocabulary.label_to_id[vocabulary.PAD]
         eos_idx = vocabulary.label_to_id[vocabulary.EOS]
         ignore_idx = [vocabulary.label_to_id[vocabulary.SOS], vocabulary.label_to_id[vocabulary.UNK]]
         metrics: Dict[str, Metric] = {
-            f"{holdout}_f1": SequentialF1Score(pad_idx=self.__pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx)
+            f"{holdout}_f1": SequentialF1Score(pad_idx=self._pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx)
             for holdout in ["train", "val", "test"]
         }
         id2label = {v: k for k, v in vocabulary.label_to_id.items()}
         metrics.update(
-            {f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self.__pad_idx, eos_idx]) for holdout in ["val", "test"]}
+            {f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self._pad_idx, eos_idx]) for holdout in ["val", "test"]}
         )
-        self.__metrics = MetricCollection(metrics)
+        self._metrics = MetricCollection(metrics)
 
         self._encoder = self._get_encoder(model_config)
-        decoder_step = LSTMDecoderStep(model_config, len(vocabulary.label_to_id), self.__pad_idx)
+        decoder_step = LSTMDecoderStep(model_config, len(vocabulary.label_to_id), self._pad_idx)
         self._decoder = Decoder(
             decoder_step, len(vocabulary.label_to_id), vocabulary.label_to_id[vocabulary.SOS], teacher_forcing
         )
-        self.__loss = SequenceCrossEntropyLoss(self.__pad_idx, reduction="batch-mean")
+        self._loss = SequenceCrossEntropyLoss(self._pad_idx, reduction="batch-mean")
 
     @property
     def vocabulary(self) -> Vocabulary:
@@ -107,16 +107,18 @@ def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict:
         target_sequence = batch.labels if step == "train" else None
         # [seq length; batch size; vocab size]
         logits, _ = self.logits_from_batch(batch, target_sequence)
-        result = {f"{step}/loss": self.__loss(logits[1:], batch.labels[1:])}
+        logits = logits[1:]
+        batch.labels = batch.labels[1:]
+        result = {f"{step}/loss": self._loss(logits, batch.labels)}
 
         with torch.no_grad():
             prediction = logits.argmax(-1)
-            metric: ClassificationMetrics = self.__metrics[f"{step}_f1"](prediction, batch.labels)
+            metric: ClassificationMetrics = self._metrics[f"{step}_f1"](prediction, batch.labels)
             result.update(
                 {f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall}
             )
             if step != "train":
-                result[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"](prediction, batch.labels)
+                result[f"{step}/chrf"] = self._metrics[f"{step}_chrf"](prediction, batch.labels)
 
         return result
 
@@ -140,17 +142,17 @@ def _shared_epoch_end(self, step_outputs: EPOCH_OUTPUT, step: str):
         with torch.no_grad():
             losses = [so if isinstance(so, torch.Tensor) else so["loss"] for so in step_outputs]
             mean_loss = torch.stack(losses).mean()
-            metric = self.__metrics[f"{step}_f1"].compute()
+            metric = self._metrics[f"{step}_f1"].compute()
             log = {
                 f"{step}/loss": mean_loss,
                 f"{step}/f1": metric.f1_score,
                 f"{step}/precision": metric.precision,
                 f"{step}/recall": metric.recall,
             }
-            self.__metrics[f"{step}_f1"].reset()
+            self._metrics[f"{step}_f1"].reset()
             if step != "train":
-                log[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"].compute()
-                self.__metrics[f"{step}_chrf"].reset()
+                log[f"{step}/chrf"] = self._metrics[f"{step}_chrf"].compute()
+                self._metrics[f"{step}_chrf"].reset()
             self.log_dict(log, on_step=False, on_epoch=True)
 
     def training_epoch_end(self, step_outputs: EPOCH_OUTPUT):
diff --git a/code2seq/model/comment_code2seq.py b/code2seq/model/comment_code2seq.py
new file mode 100644
index 0000000..01db29c
--- /dev/null
+++ b/code2seq/model/comment_code2seq.py
@@ -0,0 +1,90 @@
+from typing import Dict
+
+import torch
+from commode_utils.losses import SequenceCrossEntropyLoss
+from commode_utils.metrics import SequentialF1Score, ClassificationMetrics
+from commode_utils.modules import LSTMDecoderStep, Decoder
+from omegaconf import DictConfig
+from torchmetrics import MetricCollection, Metric
+
+from code2seq.data.path_context import BatchedLabeledPathContext
+from code2seq.data.vocabulary import CommentVocabulary
+from code2seq.model import Code2Seq
+from code2seq.model.modules.transformer_comment_decoder import TransformerCommentDecoder
+from code2seq.model.modules.metrics import CommentChrF
+
+
+class CommentCode2Seq(Code2Seq):
+    def __init__(
+        self,
+        model_config: DictConfig,
+        optimizer_config: DictConfig,
+        vocabulary: CommentVocabulary,
+        teacher_forcing: float = 0.0,
+    ):
+        super(Code2Seq, self).__init__()
+
+        self.save_hyperparameters()
+        self._optim_config = optimizer_config
+        self._vocabulary = vocabulary
+
+        tokenizer = vocabulary.tokenizer
+
+        self._pad_idx = tokenizer.pad_token_id
+        self._eos_idx = tokenizer.eos_token_id
+        self._sos_idx = tokenizer.bos_token_id
+        ignore_idx = [self._sos_idx, tokenizer.unk_token_id]
+        metrics: Dict[str, Metric] = {
+            f"{holdout}_f1": SequentialF1Score(pad_idx=self._pad_idx, eos_idx=self._eos_idx, ignore_idx=ignore_idx)
+            for holdout in ["train", "val", "test"]
+        }
+
+        # TODO add concatenation and rouge-L metric
+        metrics.update({f"{holdout}_chrf": CommentChrF(tokenizer) for holdout in ["val", "test"]})
+        self._metrics = MetricCollection(metrics)
+
+        self._encoder = self._get_encoder(model_config)
+        self._decoder = self.get_decoder(model_config, tokenizer.vocab_size, teacher_forcing)
+
+        self._loss = SequenceCrossEntropyLoss(self._pad_idx, reduction="seq-mean")
+
+    def get_decoder(self, model_config: DictConfig, vocab_size: int, teacher_forcing: float) -> torch.nn.Module:
+        if model_config.decoder_type == "LSTM":
+            decoder_step = LSTMDecoderStep(model_config, vocab_size, self._pad_idx)
+            return Decoder(decoder_step, vocab_size, self._sos_idx, teacher_forcing)
+        elif model_config.decoder_type == "Transformer":
+            return TransformerCommentDecoder(
+                model_config,
+                vocab_size=vocab_size,
+                pad_token=self._pad_idx,
+                sos_token=self._sos_idx,
+                eos_token=self._eos_idx,
+                teacher_forcing=teacher_forcing,
+            )
+        else:
+            raise ValueError(f"Unknown decoder type: {model_config.decoder_type}")
+
+    def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict:
+        target_sequence = batch.labels if step != "test" else None
+        # [seq length; batch size; vocab size]
+        logits, _ = self.logits_from_batch(batch, target_sequence)
+        logits = logits[:-1]
+        batch.labels = batch.labels[1:]
+        result = {f"{step}/loss": self._loss(logits, batch.labels)}
+
+        with torch.no_grad():
+            if step != "train":
+                prediction = logits.argmax(-1)
+                metric: ClassificationMetrics = self._metrics[f"{step}_f1"](prediction, batch.labels)
+                result.update(
+                    {
+                        f"{step}/f1": metric.f1_score,
+                        f"{step}/precision": metric.precision,
+                        f"{step}/recall": metric.recall,
+                    }
+                )
+                result[f"{step}/chrf"] = self._metrics[f"{step}_chrf"](prediction, batch.labels)
+            else:
+                result.update({f"{step}/f1": 0, f"{step}/precision": 0, f"{step}/recall": 0})
+
+        return result
diff --git a/code2seq/model/modules/metrics.py b/code2seq/model/modules/metrics.py
new file mode 100644
index 0000000..17cc7e2
--- /dev/null
+++ b/code2seq/model/modules/metrics.py
@@ -0,0 +1,43 @@
+import torch
+from torchmetrics import Metric
+from sacrebleu import CHRF
+from transformers import RobertaTokenizerFast
+
+
+class CommentChrF(Metric):
+    def __init__(self, tokenizer: RobertaTokenizerFast, **kwargs):
+        super().__init__(**kwargs)
+        self.__tokenizer = tokenizer
+        self.__chrf = CHRF()
+
+        # Metric states
+        self.add_state("chrf", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum")
+        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    def update(self, predicted: torch.Tensor, target: torch.Tensor):
+        """Calculate ChrF metric on predicted tensor w.r.t. target tensor.
+
+        :param predicted: [pred seq len; batch size] -- tensor with predicted tokens
+        :param target: [target seq len; batch size] -- tensor with ground truth tokens
+        :return:
+        """
+        batch_size = target.shape[1]
+        if predicted.shape[1] != batch_size:
+            raise ValueError(f"Wrong batch size for prediction (expected: {batch_size}, actual: {predicted.shape[1]})")
+
+        for batch_idx in range(batch_size):
+            target_seq = [token.item() for token in target[:, batch_idx]]
+            predicted_seq = [token.item() for token in predicted[:, batch_idx]]
+
+            target_str = self.__tokenizer.decode(target_seq, skip_special_tokens=True)
+            predicted_str = self.__tokenizer.decode(predicted_seq, skip_special_tokens=True)
+
+            if target_str == "":
+                # An empty target string means the original sequence consisted only of special tokens
+                continue
+
+            self.chrf += self.__chrf.sentence_score(predicted_str, [target_str]).score
+            self.count += 1
+
+    def compute(self) -> torch.Tensor:
+        return self.chrf / self.count
diff --git a/code2seq/model/modules/transformer_comment_decoder.py b/code2seq/model/modules/transformer_comment_decoder.py
new file mode 100644
index 0000000..608ff4a
--- /dev/null
+++ b/code2seq/model/modules/transformer_comment_decoder.py
@@ -0,0 +1,127 @@
+import math
+
+import torch
+from commode_utils.training import cut_into_segments
+from omegaconf import DictConfig
+from torch import nn, Tensor, LongTensor
+from torch.nn import Embedding, Linear
+from torch.nn.modules.transformer import TransformerDecoderLayer, Transformer, TransformerDecoder
+from typing import Tuple
+
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, emb_size: int, dropout: float, max_token_length: int = 5000):
+        super(PositionalEncoding, self).__init__()
+
+        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
+        pos = torch.arange(0, max_token_length).reshape(max_token_length, 1)
+
+        pe = torch.zeros((max_token_length, emb_size))
+        pe[:, 0::2] = torch.sin(pos * den)
+        pe[:, 1::2] = torch.cos(pos * den)
+        pe = pe.unsqueeze(0)
+
+        self.dropout = nn.Dropout(dropout)
+        self.register_buffer("pe", pe)
+
+    def forward(self, token_embedding: Tensor):
+        output = token_embedding + self.pe[:, : token_embedding.size(1), :]
+        return self.dropout(output)
+
+
+class TokenEmbedding(nn.Module):
+    def __init__(self, vocab_size: int, emb_size):
+        super(TokenEmbedding, self).__init__()
+        self.embedding = Embedding(vocab_size, emb_size)
+        self.emb_size = emb_size
+
+    def forward(self, tokens: Tensor):
+        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
+
+
+class TransformerCommentDecoder(nn.Module):
+    def __init__(
+        self,
+        config: DictConfig,
+        vocab_size: int,
+        pad_token: int,
+        sos_token: int,
+        eos_token: int,
+        teacher_forcing: float = 0.0,
+    ):
+        super().__init__()
+        self._vocab_size = vocab_size
+        self._pad_token = pad_token
+        self._sos_token = sos_token
+        self._eos_token = eos_token
+        self._teacher_forcing = teacher_forcing
+
+        self._embedding = TokenEmbedding(vocab_size, config.decoder_size)
+        self._positional_encoding = PositionalEncoding(config.decoder_size, config.decoder_dropout)
+        decoder_layer = TransformerDecoderLayer(
+            d_model=config.decoder_size,
+            nhead=config.decoder_num_heads,
+            dim_feedforward=config.decoder_dim_feedforward,
+            dropout=config.decoder_dropout,
+            batch_first=True,
+        )
+        self._decoder = TransformerDecoder(decoder_layer, config.decoder_num_layers)
+        self._linear = Linear(config.decoder_size, vocab_size)
+
+    def decode(
+        self, target_sequence: Tensor, batched_encoder_output: Tensor, tgt_mask: Tensor, attention_mask: Tensor
+    ) -> Tensor:
+        tgt_key_padding_mask = target_sequence == self._pad_token
+
+        embedded = self._embedding(target_sequence)
+        positionally_encoded = self._positional_encoding(embedded)
+        decoded = self._decoder(
+            tgt=positionally_encoded,
+            memory=batched_encoder_output,
+            tgt_mask=tgt_mask,
+            tgt_key_padding_mask=tgt_key_padding_mask,
+            memory_key_padding_mask=attention_mask,
+        )
+        return self._linear(decoded)
+
+    def forward(
+        self,
+        encoder_output: Tensor,
+        segment_sizes: LongTensor,
+        output_size: int,
+        target_sequence: Tensor = None,
+    ) -> Tuple[Tensor, Tensor]:
+        device = encoder_output.device
+        batch_size = segment_sizes.shape[0]
+
+        batched_encoder_output, attention_mask = cut_into_segments(encoder_output, segment_sizes)
+        # TODO fill attentions with smth good
+        attentions = batched_encoder_output.new_zeros((output_size, batch_size, attention_mask.shape[1]))
+
+        if target_sequence is not None:
+            target_sequence = target_sequence.permute(1, 0)
+
+            tgt_mask = (Transformer.generate_square_subsequent_mask(output_size)).to(device)
+
+            output = self.decode(target_sequence, batched_encoder_output, tgt_mask, attention_mask)
+        else:
+            with torch.no_grad():
+                output = torch.zeros((batch_size, output_size, self._vocab_size)).to(device)
+
+                target_sequence = torch.zeros((batch_size, 1)).to(device)
+                target_sequence[:, 0] = self._sos_token
+                is_ended = torch.zeros(batch_size, dtype=torch.bool).to(device)
+
+                for i in range(output_size):
+                    tgt_mask = (Transformer.generate_square_subsequent_mask(i + 1)).to(device)
+                    logits = self.decode(target_sequence, batched_encoder_output, tgt_mask, attention_mask)
+
+                    prediction = logits.argmax(-1)[:, i]
+                    target_sequence = torch.cat((target_sequence, prediction.unsqueeze(1)), dim=1)
+                    output[:, i, :] = logits[:, i, :]
+
+                    is_ended = torch.logical_or(is_ended, (prediction == self._eos_token))
+                    if torch.count_nonzero(is_ended).item() == batch_size:
+                        break
+
+        return output.permute(1, 0, 2), attentions
diff --git a/code2seq/utils/optimization.py b/code2seq/utils/optimization.py
index 7d3e501..7b3183a 100644
--- a/code2seq/utils/optimization.py
+++ b/code2seq/utils/optimization.py
@@ -29,5 +29,5 @@ def configure_optimizers_alon(
         optimizer = Adam(parameters, optim_config.lr, weight_decay=optim_config.weight_decay)
     else:
         raise ValueError(f"Unknown optimizer name: {optim_config.optimizer}, try one of: Adam, Momentum")
-    scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: optim_config.decay_gamma ** epoch)
+    scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: optim_config.decay_gamma**epoch)
     return [optimizer], [scheduler]
diff --git a/code2seq/utils/train.py b/code2seq/utils/train.py
index 4cca70d..5d552b8 100644
--- a/code2seq/utils/train.py
+++ b/code2seq/utils/train.py
@@ -1,4 +1,3 @@
-import torch
 from commode_utils.callbacks import ModelCheckpointWithUploadCallback, PrintEpochResultCallback
 from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning import seed_everything, Trainer, LightningModule, LightningDataModule
@@ -44,6 +43,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
         log_every_n_steps=params.log_every_n_steps,
         logger=wandb_logger,
         gpus=params.gpu,
+        auto_select_gpus=True,
         callbacks=[lr_logger, early_stopping_callback, checkpoint_callback, print_epoch_result_callback, progress_bar],
         resume_from_checkpoint=config.get("checkpoint", None),
     )
diff --git a/config/comment-code2seq-lstm-java-large.yaml b/config/comment-code2seq-lstm-java-large.yaml
new file mode 100644
index 0000000..9173b22
--- /dev/null
+++ b/config/comment-code2seq-lstm-java-large.yaml
@@ -0,0 +1,65 @@
+data_folder: ./dataset
+
+checkpoint: null
+
+seed: 7
+# Training in notebooks (e.g. Google Colab) may crash if this value is too small
+progress_bar_refresh_rate: 1
+print_config: true
+
+wandb:
+  project: comment-code2seq
+  group: null
+  offline: true
+
+data:
+  num_workers: 4
+
+  base_tokenizer: "microsoft/codebert-base"
+  train_new_tokenizer: true
+  max_tokenizer_vocab: 20000
+
+  # Each token appears at least 10 times (99.2% coverage)
+  labels_count: 10
+  max_label_parts: 256
+  # Each token appears at least 1000 times (99.5% coverage)
+  tokens_count: 1000
+  max_token_parts: 5
+  path_length: 9
+
+  max_context: 200
+  random_context: true
+
+  batch_size: 128
+  test_batch_size: 128
+
+model:
+  # Encoder
+  embedding_size: 128
+  encoder_dropout: 0.25
+  encoder_rnn_size: 128
+  use_bi_rnn: true
+  rnn_num_layers: 1
+
+  # Decoder
+  decoder_type: "LSTM"
+  decoder_size: 320
+  decoder_num_layers: 1
+  rnn_dropout: 0.5
+
+optimizer:
+  optimizer: "Momentum"
+  nesterov: true
+  lr: 0.01
+  weight_decay: 0
+  decay_gamma: 0.95
+
+train:
+  gpu: 1
+  n_epochs: 100
+  patience: 10
+  clip_norm: 5
+  teacher_forcing: 1.0
+  val_every_epoch: 1
+  save_every_epoch: 1
+  log_every_n_steps: 10
\ No newline at end of file
diff --git a/config/comment-code2seq-transformer-java-large.yaml b/config/comment-code2seq-transformer-java-large.yaml
new file mode 100644
index 0000000..548f0ea
--- /dev/null
+++ b/config/comment-code2seq-transformer-java-large.yaml
@@ -0,0 +1,67 @@
+data_folder: ./dataset
+
+checkpoint: null
+
+seed: 7
+# Training in notebooks (e.g. Google Colab) may crash if this value is too small
+progress_bar_refresh_rate: 1
+print_config: true
+
+wandb:
+  project: comment-code2seq
+  group: null
+  offline: false
+
+data:
+  num_workers: 4
+
+  base_tokenizer: "microsoft/codebert-base"
+  train_new_tokenizer: true
+  max_tokenizer_vocab: 20000
+
+  # Each token appears at least 10 times (99.2% coverage)
+  labels_count: 10
+  max_label_parts: 256
+  # Each token appears at least 1000 times (99.5% coverage)
+  tokens_count: 1000
+  max_token_parts: 5
+  path_length: 9
+
+  max_context: 200
+  random_context: true
+
+  batch_size: 64
+  test_batch_size: 64
+
+model:
+  # Encoder
+  embedding_size: 128
+  encoder_dropout: 0.25
+  encoder_rnn_size: 128
+  use_bi_rnn: true
+  rnn_num_layers: 1
+
+  # Decoder
+  decoder_type: "Transformer"
+  decoder_size: 512
+  decoder_num_layers: 4
+  decoder_dim_feedforward: 2048
+  decoder_num_heads: 8
+  decoder_dropout: 0.1
+
+optimizer:
+  optimizer: "Momentum"
+  nesterov: true
+  lr: 0.01
+  weight_decay: 0
+  decay_gamma: 0.95
+
+train:
+  gpu: 1
+  n_epochs: 100
+  patience: 10
+  clip_norm: 5
+  teacher_forcing: 1.0
+  val_every_epoch: 1
+  save_every_epoch: 1
+  log_every_n_steps: 10
diff --git a/requirements.txt b/requirements.txt
index bad2311..83caa6f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,11 @@
-torch==1.10.0
-pytorch-lightning==1.5.1
-torchmetrics==0.6.0
-tqdm==4.62.3
-wandb==0.12.6
-omegaconf==2.1.1
-commode-utils==0.4.1
+torch==1.12.0
+pytorch-lightning==1.6.5
+torchmetrics==0.9.2
+tqdm==4.64.0
+wandb==0.12.21
+omegaconf==2.2.2
+commode-utils==0.4.2
+typing==3.7.4.3
+transformers==4.20.1
+setuptools==63.2.0
+sacrebleu>=2.0.0
\ No newline at end of file