Skip to content

Commit

Permalink
Merge pull request #108 from JetBrains-Research/integrate_commode_utils
Browse files Browse the repository at this point in the history
Integrate commode utils
  • Loading branch information
SpirinEgor authored Aug 22, 2021
2 parents 88b1a32 + ef8229d commit e678f8b
Show file tree
Hide file tree
Showing 69 changed files with 1,407 additions and 5,399 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ __pycache__/
*.py[cod]
*$py.class

data/
wandb/
notebooks/
outputs/
Expand Down
40 changes: 26 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,43 @@ pip install code2seq
## Usage

Minimal code example to run the model:

```python
from os.path import join
from argparse import ArgumentParser

import hydra
from code2seq.dataset import PathContextDataModule
from code2seq.model import Code2Seq
from code2seq.utils.vocabulary import Vocabulary
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer

from code2seq.data.path_context_data_module import PathContextDataModule
from code2seq.model import Code2Seq


@hydra.main(config_path="configs")
def train(config: DictConfig):
vocabulary_path = join(config.data_folder, config.dataset.name, config.vocabulary_name)
vocabulary = Vocabulary.load_vocabulary(vocabulary_path)
model = Code2Seq(config, vocabulary)
data_module = PathContextDataModule(config, vocabulary)
# Load data module
data_module = PathContextDataModule(config.data_folder, config.data)
data_module.prepare_data()
data_module.setup()

# Load model
model = Code2Seq(
config.model,
config.optimizer,
data_module.vocabulary,
config.train.teacher_forcing
)

trainer = Trainer(max_epochs=config.hyper_parameters.n_epochs)
trainer.fit(model, datamodule=data_module)


if __name__ == "__main__":
train()
__arg_parser = ArgumentParser()
__arg_parser.add_argument("config", help="Path to YAML configuration file", type=str)
__args = __arg_parser.parse_args()

__config = OmegaConf.load(__args.config)
train(__config)
```

Navigate to [code2seq/configs](code2seq/configs) to see examples of configs.
If you had any questions then feel free to open the issue.
Navigate to the [config](config) directory to see examples of configs.
If you have any questions, feel free to open an issue.
61 changes: 61 additions & 0 deletions code2seq/code2class_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from argparse import ArgumentParser
from typing import cast

import torch
from commode_utils.common import print_config
from omegaconf import DictConfig, OmegaConf

from code2seq.data.path_context_data_module import PathContextDataModule
from code2seq.model import Code2Class
from code2seq.utils.common import filter_warnings
from code2seq.utils.test import test
from code2seq.utils.train import train


def configure_arg_parser() -> ArgumentParser:
    """Build the command-line parser for the code2class wrapper script.

    The parser accepts:
      * ``mode`` -- positional, either ``train`` or ``test``;
      * ``-c/--config`` -- path to a YAML configuration file.

    The config option is required: the script entry point passes its value
    directly to ``OmegaConf.load``, so a missing value would otherwise only
    fail later with a less helpful error.

    Returns:
        The configured ``ArgumentParser``.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument("mode", help="Mode to run script", choices=["train", "test"])
    arg_parser.add_argument(
        "-c",
        "--config",
        help="Path to YAML configuration file",
        type=str,
        required=True,
    )
    return arg_parser


def train_code2class(config: DictConfig):
    """Run the code2class training pipeline described by ``config``.

    Reads ``print_config``, ``data_folder``, ``data``, ``model`` and
    ``optimizer`` from the config; the whole config is then forwarded to
    ``train`` together with the model and the data module.
    """
    filter_warnings()

    if config.print_config:
        shown_sections = ["model", "data", "train", "optimizer"]
        print_config(config, fields=shown_sections)

    # The data module is prepared and set up before its vocabulary is read.
    datamodule = PathContextDataModule(config.data_folder, config.data, is_class=True)
    datamodule.prepare_data()
    datamodule.setup()

    model = Code2Class(config.model, config.optimizer, datamodule.vocabulary)
    train(model, datamodule, config)


def test_code2class(config: DictConfig):
    """Evaluate a saved Code2Class checkpoint using the settings in ``config``.

    Reads ``data_folder``, ``data``, ``checkpoint`` and ``seed`` from the
    config. The checkpoint is loaded onto the CPU and handed to ``test``
    together with the data module and the seed.
    """
    filter_warnings()

    # Load data module.
    # Built with ``is_class=True`` exactly as in ``train_code2class`` so that
    # evaluation uses the same data-module mode the model was trained with.
    data_module = PathContextDataModule(config.data_folder, config.data, is_class=True)
    data_module.prepare_data()
    data_module.setup()

    # Load model
    code2class = Code2Class.load_from_checkpoint(config.checkpoint, map_location=torch.device("cpu"))

    test(code2class, data_module, config.seed)


if __name__ == "__main__":
    # Parse the command line, load the YAML config, then dispatch on the mode.
    __args = configure_arg_parser().parse_args()
    __config = cast(DictConfig, OmegaConf.load(__args.config))

    # Any mode other than "train" falls through to testing.
    __entry_point = train_code2class if __args.mode == "train" else test_code2class
    __entry_point(__config)
61 changes: 61 additions & 0 deletions code2seq/code2seq_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from argparse import ArgumentParser
from typing import cast

import torch
from commode_utils.common import print_config
from omegaconf import DictConfig, OmegaConf

from code2seq.data.path_context_data_module import PathContextDataModule
from code2seq.model import Code2Seq
from code2seq.utils.common import filter_warnings
from code2seq.utils.test import test
from code2seq.utils.train import train


def configure_arg_parser() -> ArgumentParser:
    """Build the command-line parser for the code2seq wrapper script.

    The parser accepts:
      * ``mode`` -- positional, either ``train`` or ``test``;
      * ``-c/--config`` -- path to a YAML configuration file.

    The config option is required: the script entry point passes its value
    directly to ``OmegaConf.load``, so a missing value would otherwise only
    fail later with a less helpful error.

    Returns:
        The configured ``ArgumentParser``.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument("mode", help="Mode to run script", choices=["train", "test"])
    arg_parser.add_argument(
        "-c",
        "--config",
        help="Path to YAML configuration file",
        type=str,
        required=True,
    )
    return arg_parser


def train_code2seq(config: DictConfig):
    """Run the code2seq training pipeline described by ``config``.

    Reads ``print_config``, ``data_folder``, ``data``, ``model``,
    ``optimizer`` and ``train.teacher_forcing`` from the config; the whole
    config is then forwarded to ``train`` together with the model and the
    data module.
    """
    filter_warnings()

    if config.print_config:
        shown_sections = ["model", "data", "train", "optimizer"]
        print_config(config, fields=shown_sections)

    # The data module is prepared and set up before its vocabulary is read.
    datamodule = PathContextDataModule(config.data_folder, config.data)
    datamodule.prepare_data()
    datamodule.setup()

    model = Code2Seq(
        config.model,
        config.optimizer,
        datamodule.vocabulary,
        config.train.teacher_forcing,
    )
    train(model, datamodule, config)


def test_code2seq(config: DictConfig):
    """Evaluate a saved Code2Seq checkpoint using the settings in ``config``.

    Reads ``data_folder``, ``data``, ``checkpoint`` and ``seed`` from the
    config. The checkpoint is loaded onto the CPU and handed to ``test``
    together with the data module and the seed.
    """
    filter_warnings()

    # Prepare and set up the data before the model is restored.
    datamodule = PathContextDataModule(config.data_folder, config.data)
    datamodule.prepare_data()
    datamodule.setup()

    # Restore the model weights onto the CPU.
    checkpoint_path = config.checkpoint
    cpu = torch.device("cpu")
    restored = Code2Seq.load_from_checkpoint(checkpoint_path, map_location=cpu)

    test(restored, datamodule, config.seed)


if __name__ == "__main__":
    # Parse the command line, load the YAML config, then dispatch on the mode.
    __args = configure_arg_parser().parse_args()
    __config = cast(DictConfig, OmegaConf.load(__args.config))

    # Any mode other than "train" falls through to testing.
    __entry_point = train_code2seq if __args.mode == "train" else test_code2seq
    __entry_point(__config)
72 changes: 0 additions & 72 deletions code2seq/configs/code2class-poj104.yaml

This file was deleted.

76 changes: 0 additions & 76 deletions code2seq/configs/code2seq-java-small.yaml

This file was deleted.

Loading

0 comments on commit e678f8b

Please sign in to comment.