Commit
(add) vignere with all possible settings on news dataset
1 parent 3ed3273
commit 89c8288
Showing 1 changed file with 284 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,284 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "be0ed4e9", | ||
"metadata": {}, | ||
"source": [ | ||
"\n", | ||
"# Vignere cipher (all possible settings, length 2) on news dataset" | ||
] | ||
}, | ||
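{ | ||
"cell_type": "markdown", | ||
"id": "a1b2c3d4", | ||
"metadata": {}, | ||
"source": [ | ||
"A minimal sketch of the transformation, for illustration only; the actual implementation lives in `ByT5Dataset.ByT5Vignere2RandomDataset` and may differ in details. With a key of length 2 over the standard 26-letter lowercase alphabet there are 26^2 = 676 possible keys, which is what \"all possible settings\" refers to." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b2c3d4e5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Illustrative sketch only; the dataset's real cipher lives in ByT5Dataset and may differ.\n", | ||
"import string\n", | ||
"\n", | ||
"ALPHABET = string.ascii_lowercase\n", | ||
"\n", | ||
"def vigenere_encrypt(plaintext: str, key: str) -> str:\n", | ||
"    # shift each letter by the matching key letter (mod 26); the key advances on letters only\n", | ||
"    out, j = [], 0\n", | ||
"    for ch in plaintext:\n", | ||
"        if ch in ALPHABET:\n", | ||
"            shift = ALPHABET.index(key[j % len(key)])\n", | ||
"            out.append(ALPHABET[(ALPHABET.index(ch) + shift) % 26])\n", | ||
"            j += 1\n", | ||
"        else:\n", | ||
"            out.append(ch)\n", | ||
"    return \"\".join(out)\n", | ||
"\n", | ||
"vigenere_encrypt(\"attack at dawn\", \"bd\")  # -> 'bwuddn bw edxq' (26 ** 2 == 676 possible length-2 keys)" | ||
] | ||
}, | ||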
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e0754331", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"\n", | ||
"# import src to path\n", | ||
"import sys\n", | ||
"import os\n", | ||
"\n", | ||
"sys.path.append(\"./enigma-transformed/src\")\n", | ||
"sys.path.append(\"./src\")\n", | ||
"sys.path.append(\"../src\")\n", | ||
"sys.path.append(\"../../src\")\n", | ||
"\n", | ||
"if __name__ == \"__main__\":\n", | ||
" # try get SLURM JOB ID\n", | ||
" try:\n", | ||
" job_id = os.environ[\"SLURM_JOB_ID\"]\n", | ||
" except:\n", | ||
" job_id = \"debug\"\n", | ||
" logdir = f\"logs/slurm_{job_id}\"\n", | ||
" os.makedirs(logdir, exist_ok=True)\n", | ||
"\n", | ||
"\n", | ||
"# ## Setup and hyperparameters" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b7d65cba", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"\n", | ||
"from utils import calculate_batch_size\n", | ||
"dataset_size = 20000\n", | ||
"dataset_min_len = 100\n", | ||
"dataset_max_len = 100\n", | ||
"seed = 39 # reproducible\n", | ||
"evaluate_on_test = True \n", | ||
"device = 'cuda:0'\n", | ||
"train_epochs = 70\n", | ||
"lr = 2e-3\n", | ||
"warmup_ratio = .2\n", | ||
"\n", | ||
"tartget_batch_size = 256\n", | ||
"batch_size, grad_acc_steps = calculate_batch_size(tartget_batch_size, dataset_max_len)\n", | ||
"\n" | ||
] | ||
}, | ||
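{ | ||
"cell_type": "markdown", | ||
"id": "c3d4e5f6", | ||
"metadata": {}, | ||
"source": [ | ||
"`calculate_batch_size` comes from `src/utils.py` and is not part of this diff. Presumably it splits the target effective batch size into a per-device batch size that fits in memory for the given sequence length, plus matching gradient-accumulation steps; a hypothetical sketch of that contract follows." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d4e5f6a7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Hypothetical sketch, not the repository's implementation: cap the per-device batch size by a\n", | ||
"# token budget, shrink it to a divisor of the target, and recover the rest via gradient accumulation.\n", | ||
"def calculate_batch_size_sketch(target_batch_size, max_len, max_tokens_per_step=8192):\n", | ||
"    per_device = max(1, min(target_batch_size, max_tokens_per_step // max_len))\n", | ||
"    while target_batch_size % per_device:  # shrink until it divides the target evenly\n", | ||
"        per_device -= 1\n", | ||
"    grad_acc_steps = target_batch_size // per_device\n", | ||
"    return per_device, grad_acc_steps\n", | ||
"\n", | ||
"calculate_batch_size_sketch(256, 100)  # -> (64, 4) with these illustrative defaults" | ||
] | ||
}, | ||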
{ | ||
"cell_type": "markdown", | ||
"id": "db0aea5f", | ||
"metadata": {}, | ||
"source": [ | ||
"\n", | ||
"## Data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e82c877d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# 0. (optional) get data and preprocess it\n", | ||
"import os\n", | ||
"import utils\n", | ||
"from preprocessing import preprocess_file\n", | ||
"\n", | ||
"data_path = 'news.2012.en.shuffled.deduped'\n", | ||
"if not os.path.exists(data_path):\n", | ||
" utils.download_newscrawl(2012,'en')\n", | ||
" preprocess_file(data_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1030af22", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import ByT5Dataset\n", | ||
"import torch.utils.data\n", | ||
"from preprocessing import load_dataset\n", | ||
"\n", | ||
"dataset = load_dataset(dataset_size, dataset_min_len, dataset_max_len, data_path, seed)\n", | ||
"generator1 = torch.Generator().manual_seed(seed)\n", | ||
"train_ex, dev_ex, test_ex = torch.utils.data.random_split(\n", | ||
" dataset,\n", | ||
" [round(0.8 * dataset_size), round(0.1 * dataset_size), round(0.1 * dataset_size)],\n", | ||
" generator=generator1,\n", | ||
")\n", | ||
"train = ByT5Dataset.ByT5Vignere2RandomDataset(train_ex, max_length=dataset_max_len)\n", | ||
"dev = ByT5Dataset.ByT5Vignere2RandomDataset(dev_ex, max_length=dataset_max_len)\n", | ||
"test = ByT5Dataset.ByT5Vignere2RandomDataset(test_ex, max_length=dataset_max_len)\n", | ||
"\n" | ||
] | ||
}, | ||
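{ | ||
"cell_type": "markdown", | ||
"id": "e5f6a7b8", | ||
"metadata": {}, | ||
"source": [ | ||
"The field names below are inferred from the evaluation cell at the end of this notebook, which reads `input_text` (the enciphered string fed to the model) and `output_text` (the plaintext to recover); the exact item schema is defined in `ByT5Dataset`. A quick peek at one wrapped example:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f6a7b8c9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Inspect one training example (field names assumed from the evaluation cell below)\n", | ||
"example = train[0]\n", | ||
"print(\"input :\", example[\"input_text\"])\n", | ||
"print(\"output:\", example[\"output_text\"])" | ||
] | ||
}, | ||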
{ | ||
"cell_type": "markdown", | ||
"id": "39cb3603", | ||
"metadata": {}, | ||
"source": [ | ||
"\n", | ||
"## Model architecture" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b130d52d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"\n", | ||
"# We want a T5 architecutre but severely reduced in size\n", | ||
"from transformers import ByT5Tokenizer, AutoModelForSeq2SeqLM\n", | ||
"\n", | ||
"tokenizer = ByT5Tokenizer()\n", | ||
"model = AutoModelForSeq2SeqLM.from_pretrained(\"google/byt5-small\")\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "841a87e6", | ||
"metadata": {}, | ||
"source": [ | ||
"\n", | ||
"## Training setup" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f3538f36", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"\n", | ||
"from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n", | ||
"from transformers import (\n", | ||
" DataCollatorForSeq2Seq,\n", | ||
" Seq2SeqTrainer,\n", | ||
" Seq2SeqTrainingArguments,\n", | ||
")\n", | ||
"\n", | ||
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n", | ||
"training_args = Seq2SeqTrainingArguments(\n", | ||
" output_dir=logdir + \"/output\",\n", | ||
" evaluation_strategy=\"epoch\",\n", | ||
" num_train_epochs=train_epochs,\n", | ||
" per_device_train_batch_size=batch_size,\n", | ||
" per_device_eval_batch_size=batch_size,\n", | ||
" # accumulate gradients to simulate higher batch size\n", | ||
" gradient_accumulation_steps=grad_acc_steps,\n", | ||
" save_total_limit=0,\n", | ||
" predict_with_generate=True,\n", | ||
" push_to_hub=False,\n", | ||
" logging_dir=logdir,\n", | ||
" learning_rate=lr,\n", | ||
" warmup_ratio=warmup_ratio,\n", | ||
" save_steps=10000,\n", | ||
")\n", | ||
"\n", | ||
"\n", | ||
"# ## Training" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a5bdf4a9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"trainer = Seq2SeqTrainer(\n", | ||
" model=model,\n", | ||
" args=training_args,\n", | ||
" train_dataset=train,\n", | ||
" eval_dataset=dev,\n", | ||
" data_collator=data_collator,\n", | ||
" tokenizer=tokenizer,\n", | ||
")\n", | ||
"\n", | ||
"trainer.train()\n", | ||
"trainer.save_model(logdir + \"/model\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "851bf9e3", | ||
"metadata": {}, | ||
"source": [ | ||
"\n", | ||
"## Evaluation" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b635a4ea", | ||
"metadata": { | ||
"lines_to_next_cell": 2 | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"if evaluate_on_test:\n", | ||
" pass\n", | ||
"else:\n", | ||
" test = dev" | ||
] | ||
}, | ||
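{ | ||
"cell_type": "markdown", | ||
"id": "a7b8c9d0", | ||
"metadata": {}, | ||
"source": [ | ||
"`levensthein_distance` is imported from `src/utils.py` and is not part of this diff; the next cell uses it to count character-level errors between the generated text and the reference. Presumably it is the standard Levenshtein edit distance, sketched here under that assumption." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b8c9d0e1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Sketch of the metric assumed to be computed by utils.levensthein_distance:\n", | ||
"# the minimum number of single-character insertions, deletions, and substitutions.\n", | ||
"def edit_distance(a: str, b: str) -> int:\n", | ||
"    prev = list(range(len(b) + 1))\n", | ||
"    for i, ca in enumerate(a, start=1):\n", | ||
"        curr = [i]\n", | ||
"        for j, cb in enumerate(b, start=1):\n", | ||
"            curr.append(min(prev[j] + 1,  # deletion\n", | ||
"                            curr[j - 1] + 1,  # insertion\n", | ||
"                            prev[j - 1] + (ca != cb)))  # substitution\n", | ||
"        prev = curr\n", | ||
"    return prev[-1]\n", | ||
"\n", | ||
"edit_distance(\"kitten\", \"sitting\")  # -> 3" | ||
] | ||
}, | ||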
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ddae00df", | ||
"metadata": { | ||
"lines_to_next_cell": 2 | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from utils import levensthein_distance, print_avg_median_mode_error\n", | ||
"from transformers import pipeline, logging\n", | ||
"logging.set_verbosity(logging.ERROR)\n", | ||
"\n", | ||
"\n", | ||
"error_counts = []\n", | ||
"translate = pipeline(\"translation\", model=model, tokenizer=tokenizer, device=device)\n", | ||
"for index in range(len(test)):\n", | ||
" generated = translate(test[index][\"input_text\"], max_length=(dataset_max_len+1)*2)[0][\"translation_text\"]\n", | ||
" error_counts.append(levensthein_distance(generated, test[index][\"output_text\"]))\n", | ||
" if error_counts[-1] > 0:\n", | ||
" print(f\"Example {index}, error count {error_counts[-1]}\")\n", | ||
" print(\"In :\", test[index][\"input_text\"])\n", | ||
" print(\"Gen:\", generated)\n", | ||
" expected = test[index][\"output_text\"]\n", | ||
" print(\"Exp:\", expected)\n", | ||
" else:\n", | ||
" print(f\"Example {index} OK\")\n", | ||
" print(\"-----------------------\")\n", | ||
"\n", | ||
"print_avg_median_mode_error(error_counts)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"jupytext": { | ||
"cell_metadata_filter": "-all", | ||
"encoding": "# coding: utf-8", | ||
"executable": "/usr/bin/env python", | ||
"main_language": "python", | ||
"notebook_metadata_filter": "-all" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |