code2seq

JetBrains Research Github action: build Code style: black

PyTorch implementation of the code2seq model.

Installation

You can install the model with pip:

pip install code2seq
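
To confirm that the package landed in the active environment, a check along the following lines may help. It is a minimal sketch that uses only the standard library; the helper name `installed_version` is illustrative and not part of code2seq.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Prints a version string if `pip install code2seq` succeeded, otherwise None.
print(installed_version("code2seq"))
```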

Dataset mining

To prepare your own dataset in a storage format supported by this implementation, use one of the following:

  1. Original dataset preprocessing from the vanilla repository
  2. astminer: a tool for mining path-based representations and more, with support for multiple languages.
  3. PSIMiner: a tool for extracting PSI trees from the IntelliJ Platform and creating datasets from them.

Available checkpoints

Method name prediction

| Dataset (with link) | Checkpoint | # epochs | F1-score | Precision | Recall | ChrF |
|---|---|---|---|---|---|---|
| Java-small | link | 11 | 41.49 | 54.26 | 33.59 | 30.21 |
| Java-med | link | 10 | 48.17 | 58.87 | 40.76 | 42.32 |
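
The reported F1-scores are consistent with the harmonic mean of the precision and recall columns. The sketch below recomputes them from the table values; the function name `f1_score` is illustrative and does not claim to reproduce how the repository computes its metrics.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (both on the same scale, here percent)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (precision, recall, reported F1) taken from the table above.
reported = {
    "Java-small": (54.26, 33.59, 41.49),
    "Java-med": (58.87, 40.76, 48.17),
}

for name, (precision, recall, f1) in reported.items():
    recomputed = f1_score(precision, recall)
    print(f"{name}: recomputed {recomputed:.2f}, reported {f1:.2f}")
```

Both recomputed values round to the figures in the table.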

Configuration

The model is fully configurable through a standalone YAML file. See the config directory for example configs.
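
The training script in the Examples section reads `data_folder`, `data`, `model`, `optimizer`, `train.n_epochs`, and `train.teacher_forcing` from the loaded config. As a rough sketch, a config can be checked for these keys before training starts. The helper `missing_keys` and the placeholder values are hypothetical, and the contents of the nested sections are not specified here; the example configs in the config directory remain the authoritative reference.

```python
from typing import Any, List, Mapping, Sequence

# Top-level and nested keys accessed by the training example in this README.
REQUIRED_KEYS = (
    "data_folder",
    "data",
    "model",
    "optimizer",
    "train.n_epochs",
    "train.teacher_forcing",
)


def missing_keys(config: Mapping[str, Any], required: Sequence[str] = REQUIRED_KEYS) -> List[str]:
    """Return the dotted keys from `required` that are absent in `config`."""
    missing = []
    for dotted in required:
        node: Any = config
        for part in dotted.split("."):
            if isinstance(node, Mapping) and part in node:
                node = node[part]
            else:
                missing.append(dotted)
                break
    return missing


# Hypothetical config as a plain dict; all values are placeholders.
example = {
    "data_folder": "path/to/dataset",
    "data": {},
    "model": {},
    "optimizer": {},
    "train": {"n_epochs": 10},
}
print(missing_keys(example))  # ['train.teacher_forcing']
```

A config loaded with OmegaConf can be converted to a plain dict with `OmegaConf.to_container` before being passed to such a check.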

Examples

Model training can be done with the PyTorch Lightning Trainer. See its documentation for more information.

from argparse import ArgumentParser

from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer

from code2seq.data.path_context_data_module import PathContextDataModule
from code2seq.model import Code2Seq


def train(config: DictConfig):
    # Define data module
    data_module = PathContextDataModule(config.data_folder, config.data)

    # Define model
    model = Code2Seq(
        config.model,
        config.optimizer,
        data_module.vocabulary,
        config.train.teacher_forcing
    )

    # Define hyper parameters
    trainer = Trainer(max_epochs=config.train.n_epochs)

    # Train model
    trainer.fit(model, datamodule=data_module)


if __name__ == "__main__":
    __arg_parser = ArgumentParser()
    __arg_parser.add_argument("config", help="Path to YAML configuration file", type=str)
    __args = __arg_parser.parse_args()

    __config = OmegaConf.load(__args.config)
    train(__config)