Comment prediction added (CommentCode2Seq) #120

Open
wants to merge 33 commits into base: master
Changes from 12 commits
Commits (33)
628c5e5
Added useful classes
malodetz Jul 14, 2022
1ff1994
Config added
malodetz Jul 14, 2022
f8ffa00
Added comment label processing
malodetz Jul 15, 2022
5f39e68
Added wrapper
malodetz Jul 15, 2022
8620414
Update requirements.txt
malodetz Jul 15, 2022
7624d44
Fixing train to use gpu
malodetz Jul 15, 2022
fb5fb26
New model with correct decoder
malodetz Jul 15, 2022
c4a203b
Fix black
malodetz Jul 15, 2022
6fdee03
Added custom chrf metric
malodetz Jul 15, 2022
db70270
Minor updates
malodetz Jul 16, 2022
b68cc10
Fix random
malodetz Jul 16, 2022
24c8484
Fixing chrf and f1
malodetz Jul 20, 2022
33db6b5
Preliminary new tokenizer
malodetz Jul 28, 2022
dc15cdb
Complete new tokenizer
malodetz Jul 28, 2022
26dd42f
Some fixes
malodetz Jul 28, 2022
f6fa424
New vocab size
malodetz Jul 28, 2022
e8b677d
Add tokenizer to config
malodetz Aug 7, 2022
b1b2e27
Move chrf metric
malodetz Aug 8, 2022
e70f96d
Better tokenization
malodetz Aug 8, 2022
5e93981
Implement comment transformer decoder
malodetz Aug 10, 2022
033560c
Greedy decoding for val/test
malodetz Aug 10, 2022
c74319f
Small fix
malodetz Aug 10, 2022
b1009cd
Some fixes
malodetz Aug 12, 2022
6aff1ed
Logits cut
malodetz Aug 12, 2022
c23e20a
Fixed greedy generation
malodetz Aug 12, 2022
37f3db2
Some train changes to fix
malodetz Aug 12, 2022
a19e752
Fix train
malodetz Aug 25, 2022
e43b664
Config for transformer decoder
malodetz Aug 30, 2022
c51f14c
Multiple decoders
malodetz Aug 30, 2022
d09dc30
Another config
malodetz Aug 30, 2022
9bfa25c
Add early stop for greedy generation
malodetz Aug 30, 2022
b0b7b81
Early generation stop
malodetz Aug 30, 2022
54cb711
Added predictions
malodetz Sep 17, 2022
63 changes: 63 additions & 0 deletions code2seq/comment_code2seq_wrapper.py
@@ -0,0 +1,63 @@
from argparse import ArgumentParser
from typing import cast

import torch
from commode_utils.common import print_config
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import seed_everything

from code2seq.data.comment_path_context_data_module import CommentPathContextDataModule
from code2seq.model.comment_code2seq import CommentCode2Seq
from code2seq.utils.common import filter_warnings
from code2seq.utils.test import test
from code2seq.utils.train import train


def configure_arg_parser() -> ArgumentParser:
    arg_parser = ArgumentParser()
    arg_parser.add_argument("mode", help="Mode to run script", choices=["train", "test"])
    arg_parser.add_argument("-c", "--config", help="Path to YAML configuration file", type=str)
    arg_parser.add_argument(
        "-p", "--pretrained", help="Path to pretrained model", type=str, required=False, default=None
    )
    return arg_parser


def train_code2seq(config: DictConfig):
    filter_warnings()

    if config.print_config:
        print_config(config, fields=["model", "data", "train", "optimizer"])

    # Load data module
    data_module = CommentPathContextDataModule(config.data_folder, config.data)

    # Load model
    code2seq = CommentCode2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing)

    train(code2seq, data_module, config)


def test_code2seq(model_path: str, config: DictConfig):
    filter_warnings()

    # Load data module
    data_module = CommentPathContextDataModule(config.data_folder, config.data)

    # Load model
    code2seq = CommentCode2Seq.load_from_checkpoint(model_path, map_location=torch.device("cpu"))

    test(code2seq, data_module, config.seed)


if __name__ == "__main__":
    __arg_parser = configure_arg_parser()
    __args = __arg_parser.parse_args()

    __config = cast(DictConfig, OmegaConf.load(__args.config))
    seed_everything(__config.seed)
    if __args.mode == "train":
        train_code2seq(__config)
    else:
        assert __args.pretrained is not None
        test_code2seq(__args.pretrained, __config)
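
A minimal invocation sketch for this wrapper, assuming the config added later in this PR; the checkpoint path in the test command is illustrative:

# train a CommentCode2Seq model from scratch
python code2seq/comment_code2seq_wrapper.py train -c config/comment-code2seq-java-large.yaml

# evaluate a trained checkpoint (path is illustrative)
python code2seq/comment_code2seq_wrapper.py test -c config/comment-code2seq-java-large.yaml -p checkpoints/comment-code2seq.ckpt
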
9 changes: 9 additions & 0 deletions code2seq/data/comment_path_context_data_module.py
@@ -0,0 +1,9 @@
from code2seq.data.comment_path_context_dataset import CommentPathContextDataset
from code2seq.data.path_context_data_module import PathContextDataModule


class CommentPathContextDataModule(PathContextDataModule):
    def _create_dataset(self, holdout_file: str, random_context: bool) -> CommentPathContextDataset:
        if self._vocabulary is None:
            raise RuntimeError(f"Setup vocabulary before creating data loaders")
        return CommentPathContextDataset(holdout_file, self._config, self._vocabulary, random_context)
19 changes: 19 additions & 0 deletions code2seq/data/comment_path_context_dataset.py
@@ -0,0 +1,19 @@
from typing import Dict, List, Optional

from transformers import RobertaTokenizerFast

from code2seq.data.path_context_dataset import PathContextDataset

tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base")


class CommentPathContextDataset(PathContextDataset):
    @staticmethod
    def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]:
        label_with_spaces = " ".join(raw_label.split(PathContextDataset._separator))
        label_tokens = tokenizer.tokenize(label_with_spaces)
        if max_parts is None:
            max_parts = len(label_tokens)
        label_tokens = [tokenizer.bos_token] + label_tokens[: max_parts - 2] + [tokenizer.eos_token]
        label_tokens += [tokenizer.pad_token] * (max_parts - len(label_tokens))
        return tokenizer.convert_tokens_to_ids(label_tokens)
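
A rough sketch of what this label processing produces, assuming PathContextDataset._separator is "|" and an illustrative raw label; the exact subtoken split and ids depend on the CodeBERT vocabulary:

from code2seq.data.comment_path_context_dataset import CommentPathContextDataset

# "returns|the|maximum|value" -> "returns the maximum value"
# -> CodeBERT subtokens wrapped in <s> ... </s>, truncated/padded to max_parts, then mapped to ids
label_ids = CommentPathContextDataset.tokenize_label("returns|the|maximum|value", vocab={}, max_parts=8)
print(len(label_ids))  # 8: <s> + subtokens (truncated to at most 6) + </s>, padded with <pad> up to max_parts
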
28 changes: 15 additions & 13 deletions code2seq/model/code2seq.py
@@ -35,26 +35,26 @@ def __init__(
         if vocabulary.SOS not in vocabulary.label_to_id:
             raise ValueError(f"Can't find SOS token in label to id vocabulary")

-        self.__pad_idx = vocabulary.label_to_id[vocabulary.PAD]
+        self._pad_idx = vocabulary.label_to_id[vocabulary.PAD]
         eos_idx = vocabulary.label_to_id[vocabulary.EOS]
         ignore_idx = [vocabulary.label_to_id[vocabulary.SOS], vocabulary.label_to_id[vocabulary.UNK]]
         metrics: Dict[str, Metric] = {
-            f"{holdout}_f1": SequentialF1Score(pad_idx=self.__pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx)
+            f"{holdout}_f1": SequentialF1Score(pad_idx=self._pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx)
             for holdout in ["train", "val", "test"]
         }
         id2label = {v: k for k, v in vocabulary.label_to_id.items()}
         metrics.update(
-            {f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self.__pad_idx, eos_idx]) for holdout in ["val", "test"]}
+            {f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self._pad_idx, eos_idx]) for holdout in ["val", "test"]}
         )
-        self.__metrics = MetricCollection(metrics)
+        self._metrics = MetricCollection(metrics)

         self._encoder = self._get_encoder(model_config)
-        decoder_step = LSTMDecoderStep(model_config, len(vocabulary.label_to_id), self.__pad_idx)
+        decoder_step = LSTMDecoderStep(model_config, len(vocabulary.label_to_id), self._pad_idx)
         self._decoder = Decoder(
             decoder_step, len(vocabulary.label_to_id), vocabulary.label_to_id[vocabulary.SOS], teacher_forcing
         )

-        self.__loss = SequenceCrossEntropyLoss(self.__pad_idx, reduction="batch-mean")
+        self._loss = SequenceCrossEntropyLoss(self._pad_idx, reduction="batch-mean")

     @property
     def vocabulary(self) -> Vocabulary:
@@ -107,16 +107,18 @@ def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict:
         target_sequence = batch.labels if step == "train" else None
         # [seq length; batch size; vocab size]
         logits, _ = self.logits_from_batch(batch, target_sequence)
-        result = {f"{step}/loss": self.__loss(logits[1:], batch.labels[1:])}
+        logits = logits[1:]
+        batch.labels = batch.labels[1:]
+        result = {f"{step}/loss": self._loss(logits, batch.labels)}

         with torch.no_grad():
             prediction = logits.argmax(-1)
-            metric: ClassificationMetrics = self.__metrics[f"{step}_f1"](prediction, batch.labels)
+            metric: ClassificationMetrics = self._metrics[f"{step}_f1"](prediction, batch.labels)
             result.update(
                 {f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall}
             )
             if step != "train":
-                result[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"](prediction, batch.labels)
+                result[f"{step}/chrf"] = self._metrics[f"{step}_chrf"](prediction, batch.labels)

         return result

@@ -140,17 +142,17 @@ def _shared_epoch_end(self, step_outputs: EPOCH_OUTPUT, step: str):
         with torch.no_grad():
             losses = [so if isinstance(so, torch.Tensor) else so["loss"] for so in step_outputs]
             mean_loss = torch.stack(losses).mean()
-            metric = self.__metrics[f"{step}_f1"].compute()
+            metric = self._metrics[f"{step}_f1"].compute()
             log = {
                 f"{step}/loss": mean_loss,
                 f"{step}/f1": metric.f1_score,
                 f"{step}/precision": metric.precision,
                 f"{step}/recall": metric.recall,
             }
-            self.__metrics[f"{step}_f1"].reset()
+            self._metrics[f"{step}_f1"].reset()
             if step != "train":
-                log[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"].compute()
-                self.__metrics[f"{step}_chrf"].reset()
+                log[f"{step}/chrf"] = self._metrics[f"{step}_chrf"].compute()
+                self._metrics[f"{step}_chrf"].reset()
         self.log_dict(log, on_step=False, on_epoch=True)

     def training_epoch_end(self, step_outputs: EPOCH_OUTPUT):
84 changes: 84 additions & 0 deletions code2seq/model/comment_code2seq.py
@@ -0,0 +1,84 @@
from typing import Dict

import torch
from commode_utils.losses import SequenceCrossEntropyLoss
from commode_utils.metrics import SequentialF1Score
from commode_utils.modules import LSTMDecoderStep, Decoder
from omegaconf import DictConfig
from sacrebleu import CHRF
from torchmetrics import MetricCollection, Metric
from transformers import RobertaTokenizerFast

from code2seq.data.vocabulary import Vocabulary
from code2seq.model import Code2Seq


class CommentCode2Seq(Code2Seq):
    def __init__(
        self,
        model_config: DictConfig,
        optimizer_config: DictConfig,
        vocabulary: Vocabulary,
        teacher_forcing: float = 0.0,
    ):
        super().__init__(model_config, optimizer_config, vocabulary, teacher_forcing)

        tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base")

        self._pad_idx = tokenizer.pad_token_id
        eos_idx = tokenizer.eos_token_id
        ignore_idx = [tokenizer.bos_token_id, tokenizer.unk_token_id]
        metrics: Dict[str, Metric] = {
            f"{holdout}_f1": SequentialF1Score(pad_idx=self._pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx)
            for holdout in ["train", "val", "test"]
        }

        # TODO add concatenation and rouge-L metric
        metrics.update({f"{holdout}_chrf": CommentChrF(tokenizer) for holdout in ["val", "test"]})
        self._metrics = MetricCollection(metrics)

        self._encoder = self._get_encoder(model_config)
        output_size = len(tokenizer.get_vocab())
        decoder_step = LSTMDecoderStep(model_config, output_size, self._pad_idx)
        self._decoder = Decoder(decoder_step, output_size, tokenizer.eos_token_id, teacher_forcing)

        self._loss = SequenceCrossEntropyLoss(self._pad_idx, reduction="batch-mean")


class CommentChrF(Metric):
    def __init__(self, tokenizer: RobertaTokenizerFast, **kwargs):
        super().__init__(**kwargs)
        self.__tokenizer = tokenizer
        self.__chrf = CHRF()

        # Metric states
        self.add_state("chrf", default=torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, predicted: torch.Tensor, target: torch.Tensor):
        """Calculate the ChrF metric on a predicted tensor w.r.t. the target tensor.

        :param predicted: [pred seq len; batch size] -- tensor with predicted tokens
        :param target: [target seq len; batch size] -- tensor with ground truth tokens
        :return:
        """
        batch_size = target.shape[1]
        if predicted.shape[1] != batch_size:
            raise ValueError(f"Wrong batch size for prediction (expected: {batch_size}, actual: {predicted.shape[1]})")

        for batch_idx in range(batch_size):
            target_seq = [token.item() for token in target[:, batch_idx]]
            predicted_seq = [token.item() for token in predicted[:, batch_idx]]

            target_str = self.__tokenizer.decode(target_seq, skip_special_tokens=True)
            predicted_str = self.__tokenizer.decode(predicted_seq, skip_special_tokens=True)

            if target_str == "":
                # An empty target string means the original label was encoded only with <UNK> tokens
                continue

            self.chrf += self.__chrf.sentence_score(predicted_str, [target_str]).score
            self.count += 1

    def compute(self) -> torch.Tensor:
        return self.chrf / self.count
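
A small, self-contained sketch of how CommentChrF could be exercised on its own; shapes follow the [seq len; batch size] convention from the docstring, and the sentences are illustrative:

import torch
from transformers import RobertaTokenizerFast

from code2seq.model.comment_code2seq import CommentChrF

tokenizer = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base")
metric = CommentChrF(tokenizer)

# Build a one-element "batch": tensors of shape [seq len; batch size = 1]
target = torch.tensor(tokenizer("returns the maximum value").input_ids).unsqueeze(1)
predicted = torch.tensor(tokenizer("return maximum value").input_ids).unsqueeze(1)

metric.update(predicted, target)
print(metric.compute())  # average sentence-level ChrF over all accumulated pairs
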
2 changes: 1 addition & 1 deletion code2seq/utils/optimization.py
@@ -29,5 +29,5 @@ def configure_optimizers_alon(
         optimizer = Adam(parameters, optim_config.lr, weight_decay=optim_config.weight_decay)
     else:
         raise ValueError(f"Unknown optimizer name: {optim_config.optimizer}, try one of: Adam, Momentum")
-    scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: optim_config.decay_gamma ** epoch)
+    scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: optim_config.decay_gamma**epoch)
     return [optimizer], [scheduler]
1 change: 1 addition & 0 deletions code2seq/utils/train.py
@@ -44,6 +44,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
         log_every_n_steps=params.log_every_n_steps,
         logger=wandb_logger,
         gpus=params.gpu,
+        auto_select_gpus=True,
         callbacks=[lr_logger, early_stopping_callback, checkpoint_callback, print_epoch_result_callback, progress_bar],
         resume_from_checkpoint=config.get("checkpoint", None),
     )
60 changes: 60 additions & 0 deletions config/comment-code2seq-java-large.yaml
@@ -0,0 +1,60 @@
data_folder: ./dataset

checkpoint: null

seed: 7
# Training in notebooks (e.g. Google Colab) may crash with too small a value
progress_bar_refresh_rate: 1
print_config: true

wandb:
  project: Code2Seq -- java-med
  group: null
  offline: false

data:
  num_workers: 4

  # Each token appears at least 10 times (99.2% coverage)
  labels_count: 1
  max_label_parts: 256
  # Each token appears at least 1000 times (99.5% coverage)
  tokens_count: 1
  max_token_parts: 5
  path_length: 9

  max_context: 200
  random_context: true

  batch_size: 16
  test_batch_size: 16

model:
  # Encoder
  embedding_size: 128
  encoder_dropout: 0.25
  encoder_rnn_size: 128
  use_bi_rnn: true
  rnn_num_layers: 1

  # Decoder
  decoder_size: 320
  decoder_num_layers: 1
  rnn_dropout: 0.5

optimizer:
  optimizer: "Momentum"
  nesterov: true
  lr: 0.01
  weight_decay: 0
  decay_gamma: 0.95

train:
  gpu: 1
  n_epochs: 100
  patience: 10
  clip_norm: 5
  teacher_forcing: 1.0
  val_every_epoch: 1
  save_every_epoch: 1
  log_every_n_steps: 10
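
A minimal sketch of loading this config with OmegaConf and overriding a value from code; the override shown is illustrative:

from omegaconf import OmegaConf

config = OmegaConf.load("config/comment-code2seq-java-large.yaml")
config.train.n_epochs = 10  # illustrative override for a quick smoke-test run
print(OmegaConf.to_yaml(config.train))
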
18 changes: 11 additions & 7 deletions requirements.txt
@@ -1,7 +1,11 @@
-torch==1.10.0
-pytorch-lightning==1.5.1
-torchmetrics==0.6.0
-tqdm==4.62.3
-wandb==0.12.6
-omegaconf==2.1.1
-commode-utils==0.4.1
+torch==1.12.0
+pytorch-lightning==1.6.5
+torchmetrics==0.9.2
+tqdm==4.64.0
+wandb==0.12.21
+omegaconf==2.2.2
+commode-utils==0.4.2
+typing==3.7.4.3
+transformers==4.20.1
+setuptools==63.2.0
+sacrebleu>=2.0.0