(add) vignere with all possible settings on news dataset
JanProvaznik committed Oct 23, 2023
1 parent 3ed3273 commit 89c8288
Showing 1 changed file with 284 additions and 0 deletions.
284 changes: 284 additions & 0 deletions reproducible/13_vignere_random_news.ipynb
@@ -0,0 +1,284 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "be0ed4e9",
"metadata": {},
"source": [
"\n",
"# Vignere cipher (all possible settings, length 2) on news dataset"
]
},
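{
"cell_type": "markdown",
"id": "vigenere-sketch-md",
"metadata": {},
"source": [
"A minimal sketch of a key-length-2 Vigenère substitution over lowercase a-z, assuming non-letters pass through unchanged; the transform actually applied to the data below is `ByT5Dataset.ByT5Vignere2RandomDataset`, whose exact alphabet handling may differ."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "vigenere-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: shows the cipher the model learns to invert, not the repo's implementation.\n",
"def vigenere2_encrypt(plaintext, key):\n",
"    assert len(key) == 2\n",
"    key = key.lower()\n",
"    out, i = [], 0\n",
"    for ch in plaintext:\n",
"        if ch.isalpha():\n",
"            shift = ord(key[i % 2]) - ord('a')  # per-position shift from the 2-letter key\n",
"            out.append(chr((ord(ch.lower()) - ord('a') + shift) % 26 + ord('a')))\n",
"            i += 1\n",
"        else:\n",
"            out.append(ch)  # non-letters kept as-is in this sketch\n",
"    return ''.join(out)\n",
"\n",
"print(vigenere2_encrypt('hello world', 'bc'))  # -> 'igmnp yptmf'"
]
},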
{
"cell_type": "code",
"execution_count": null,
"id": "e0754331",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# import src to path\n",
"import sys\n",
"import os\n",
"\n",
"sys.path.append(\"./enigma-transformed/src\")\n",
"sys.path.append(\"./src\")\n",
"sys.path.append(\"../src\")\n",
"sys.path.append(\"../../src\")\n",
"\n",
"if __name__ == \"__main__\":\n",
" # try get SLURM JOB ID\n",
" try:\n",
" job_id = os.environ[\"SLURM_JOB_ID\"]\n",
" except:\n",
" job_id = \"debug\"\n",
" logdir = f\"logs/slurm_{job_id}\"\n",
" os.makedirs(logdir, exist_ok=True)\n",
"\n",
"\n",
"# ## Setup and hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7d65cba",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"from utils import calculate_batch_size\n",
"dataset_size = 20000\n",
"dataset_min_len = 100\n",
"dataset_max_len = 100\n",
"seed = 39 # reproducible\n",
"evaluate_on_test = True \n",
"device = 'cuda:0'\n",
"train_epochs = 70\n",
"lr = 2e-3\n",
"warmup_ratio = .2\n",
"\n",
"tartget_batch_size = 256\n",
"batch_size, grad_acc_steps = calculate_batch_size(tartget_batch_size, dataset_max_len)\n",
"\n"
]
},
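{
"cell_type": "markdown",
"id": "batch-size-sketch-md",
"metadata": {},
"source": [
"`calculate_batch_size` comes from this repo's `utils` and its real logic is not shown here; the sketch below is only a plausible stand-in, assuming it splits a target effective batch size into a per-device batch size and gradient-accumulation steps, with longer sequences getting a smaller per-device batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "batch-size-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative stand-in, not the repo's calculate_batch_size; char_budget is an assumed memory heuristic.\n",
"def calculate_batch_size_sketch(target_batch_size, max_len, char_budget=16384):\n",
"    per_device = max(1, min(target_batch_size, char_budget // max(max_len, 1)))\n",
"    grad_acc_steps = -(-target_batch_size // per_device)  # ceil division\n",
"    return per_device, grad_acc_steps\n",
"\n",
"print(calculate_batch_size_sketch(256, 100))  # (163, 2) under this toy heuristic"
]
},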
{
"cell_type": "markdown",
"id": "db0aea5f",
"metadata": {},
"source": [
"\n",
"## Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e82c877d",
"metadata": {},
"outputs": [],
"source": [
"# 0. (optional) get data and preprocess it\n",
"import os\n",
"import utils\n",
"from preprocessing import preprocess_file\n",
"\n",
"data_path = 'news.2012.en.shuffled.deduped'\n",
"if not os.path.exists(data_path):\n",
" utils.download_newscrawl(2012,'en')\n",
" preprocess_file(data_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1030af22",
"metadata": {},
"outputs": [],
"source": [
"import ByT5Dataset\n",
"import torch.utils.data\n",
"from preprocessing import load_dataset\n",
"\n",
"dataset = load_dataset(dataset_size, dataset_min_len, dataset_max_len, data_path, seed)\n",
"generator1 = torch.Generator().manual_seed(seed)\n",
"train_ex, dev_ex, test_ex = torch.utils.data.random_split(\n",
" dataset,\n",
" [round(0.8 * dataset_size), round(0.1 * dataset_size), round(0.1 * dataset_size)],\n",
" generator=generator1,\n",
")\n",
"train = ByT5Dataset.ByT5Vignere2RandomDataset(train_ex, max_length=dataset_max_len)\n",
"dev = ByT5Dataset.ByT5Vignere2RandomDataset(dev_ex, max_length=dataset_max_len)\n",
"test = ByT5Dataset.ByT5Vignere2RandomDataset(test_ex, max_length=dataset_max_len)\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "39cb3603",
"metadata": {},
"source": [
"\n",
"## Model architecture"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b130d52d",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# We want a T5 architecutre but severely reduced in size\n",
"from transformers import ByT5Tokenizer, AutoModelForSeq2SeqLM\n",
"\n",
"tokenizer = ByT5Tokenizer()\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(\"google/byt5-small\")\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "841a87e6",
"metadata": {},
"source": [
"\n",
"## Training setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3538f36",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n",
"from transformers import (\n",
" DataCollatorForSeq2Seq,\n",
" Seq2SeqTrainer,\n",
" Seq2SeqTrainingArguments,\n",
")\n",
"\n",
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
"training_args = Seq2SeqTrainingArguments(\n",
" output_dir=logdir + \"/output\",\n",
" evaluation_strategy=\"epoch\",\n",
" num_train_epochs=train_epochs,\n",
" per_device_train_batch_size=batch_size,\n",
" per_device_eval_batch_size=batch_size,\n",
" # accumulate gradients to simulate higher batch size\n",
" gradient_accumulation_steps=grad_acc_steps,\n",
" save_total_limit=0,\n",
" predict_with_generate=True,\n",
" push_to_hub=False,\n",
" logging_dir=logdir,\n",
" learning_rate=lr,\n",
" warmup_ratio=warmup_ratio,\n",
" save_steps=10000,\n",
")\n",
"\n",
"\n",
"# ## Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5bdf4a9",
"metadata": {},
"outputs": [],
"source": [
"trainer = Seq2SeqTrainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=train,\n",
" eval_dataset=dev,\n",
" data_collator=data_collator,\n",
" tokenizer=tokenizer,\n",
")\n",
"\n",
"trainer.train()\n",
"trainer.save_model(logdir + \"/model\")"
]
},
{
"cell_type": "markdown",
"id": "851bf9e3",
"metadata": {},
"source": [
"\n",
"## Evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b635a4ea",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"\n",
"if evaluate_on_test:\n",
" pass\n",
"else:\n",
" test = dev"
]
},
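{
"cell_type": "markdown",
"id": "levenshtein-sketch-md",
"metadata": {},
"source": [
"`levensthein_distance` is imported from this repo's `utils`; the metric it is assumed to compute is the standard Levenshtein (edit) distance, i.e. the minimum number of single-character insertions, deletions and substitutions needed to turn the generated string into the expected one. A minimal reference sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "levenshtein-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: standard dynamic-programming edit distance, not the repo's helper.\n",
"def levenshtein(a, b):\n",
"    prev = list(range(len(b) + 1))\n",
"    for i, ca in enumerate(a, 1):\n",
"        cur = [i]\n",
"        for j, cb in enumerate(b, 1):\n",
"            cur.append(min(prev[j] + 1,                  # deletion\n",
"                           cur[j - 1] + 1,               # insertion\n",
"                           prev[j - 1] + (ca != cb)))    # substitution (free if equal)\n",
"        prev = cur\n",
"    return prev[-1]\n",
"\n",
"print(levenshtein('kitten', 'sitting'))  # 3"
]
},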
{
"cell_type": "code",
"execution_count": null,
"id": "ddae00df",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"from utils import levensthein_distance, print_avg_median_mode_error\n",
"from transformers import pipeline, logging\n",
"logging.set_verbosity(logging.ERROR)\n",
"\n",
"\n",
"error_counts = []\n",
"translate = pipeline(\"translation\", model=model, tokenizer=tokenizer, device=device)\n",
"for index in range(len(test)):\n",
" generated = translate(test[index][\"input_text\"], max_length=(dataset_max_len+1)*2)[0][\"translation_text\"]\n",
" error_counts.append(levensthein_distance(generated, test[index][\"output_text\"]))\n",
" if error_counts[-1] > 0:\n",
" print(f\"Example {index}, error count {error_counts[-1]}\")\n",
" print(\"In :\", test[index][\"input_text\"])\n",
" print(\"Gen:\", generated)\n",
" expected = test[index][\"output_text\"]\n",
" print(\"Exp:\", expected)\n",
" else:\n",
" print(f\"Example {index} OK\")\n",
" print(\"-----------------------\")\n",
"\n",
"print_avg_median_mode_error(error_counts)"
]
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"encoding": "# coding: utf-8",
"executable": "/usr/bin/env python",
"main_language": "python",
"notebook_metadata_filter": "-all"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
