feat: add lora fine tuning for llama 3.2
jfrery committed Dec 10, 2024
1 parent 1ccaceb commit f5843b3
Showing 15 changed files with 7,129 additions and 833 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/refresh-one-notebook.yaml
@@ -28,6 +28,7 @@ on:
- KNearestNeighbors \n
- LinearRegression \n
- LinearSVR \n
- LLamaFineTuning \n
- LogisticRegression \n
- LogisticRegressionTraining \n
- LoraMLP \n
@@ -76,6 +77,7 @@ env:
KNearestNeighbors: "docs/advanced_examples/KNearestNeighbors.ipynb"
LinearRegression: "docs/advanced_examples/LinearRegression.ipynb"
LinearSVR: "docs/advanced_examples/LinearSVR.ipynb"
LLamaFineTuning: "use_case_examples/lora_finetuning/LLamaFineTuning.ipynb"
LogisticRegression: "docs/advanced_examples/LogisticRegression.ipynb"
LogisticRegressionTraining: "docs/advanced_examples/LogisticRegressionTraining.ipynb"
LoraMLP: "docs/advanced_examples/LoraMLP.ipynb"
280 changes: 23 additions & 257 deletions docs/advanced_examples/LoraMLP.ipynb
@@ -21,7 +21,7 @@
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7ffa268f2530>"
"<torch._C.Generator at 0x7d9bca770290>"
]
},
"execution_count": 1,
@@ -31,7 +31,6 @@
],
"source": [
"import shutil\n",
"import time\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
@@ -41,10 +40,8 @@
"from sklearn.datasets import make_circles, make_moons\n",
"from torch import nn, optim\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"from tqdm import tqdm\n",
"\n",
"from concrete.ml.torch.hybrid_model import HybridFHEModel\n",
"from concrete.ml.torch.lora import LoraTraining, get_remote_names\n",
"from concrete.ml.torch.lora import LoraTrainer\n",
"\n",
"# Set random seed for reproducibility\n",
"SEED = 42\n",
@@ -132,13 +129,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Training on Task 1 without LoRA:\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training on Task 1 without LoRA:\n",
"Epoch [20/20], Loss: 0.0036\n"
]
},
@@ -276,224 +267,47 @@
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LoRA layers detected in the model.\n"
]
}
],
"source": [
"# Set up LoRA training\n",
"lora_training = LoraTraining(peft_model)\n",
"\n",
"# Set up optimizer and scheduler\n",
"# Update training parameters, including loss function\n",
"optimizer = optim.Adam(filter(lambda p: p.requires_grad, peft_model.parameters()), lr=0.01)\n",
"loss_fn = nn.CrossEntropyLoss()\n",
"training_args = {\"gradient_accumulation_steps\": 1}\n",
"\n",
"# Update training parameters, including loss function\n",
"lora_training.update_training_parameters(\n",
" optimizer=optimizer,\n",
" loss_fn=nn.CrossEntropyLoss(),\n",
" training_args={\"gradient_accumulation_steps\": 1},\n",
"# Set up LoRA training\n",
"lora_trainer = LoraTrainer(\n",
" peft_model, optimizer=optimizer, loss_fn=loss_fn, training_args=training_args\n",
")\n",
"\n",
"# Create the HybridFHEModel\n",
"remote_names = get_remote_names(lora_training)\n",
"hybrid_model = HybridFHEModel(lora_training, module_names=remote_names)\n",
"\n",
"# Prepare input data for calibration\n",
"batch_size_per_task = batch_size // 2\n",
"inputset = (\n",
" torch.cat([X_task1[:batch_size_per_task], X_task2[:batch_size_per_task]]),\n",
" torch.cat([y_task1[:batch_size_per_task], y_task2[:batch_size_per_task]]),\n",
")\n",
"\n",
"# Calibrate and compile the model\n",
"lora_training.toggle_calibrate(enable=True)\n",
"hybrid_model.compile_model(inputset, n_bits=8)\n",
"lora_training.toggle_calibrate(enable=False)"
"# Compile the model\n",
"lora_trainer.compile(inputset, n_bits=8)"
]
},
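The cell above replaces the previous three-step setup (instantiate LoraTraining, build a HybridFHEModel over the remote layers returned by get_remote_names, then toggle calibration around compile_model) with a single LoraTrainer object that owns the optimizer, the loss function and the training arguments, and exposes compile directly. Below is a minimal end-to-end sketch of the new flow; the TinyMLP toy model, the LoRA hyperparameters and the random calibration batch are illustrative assumptions, while the LoraTrainer calls mirror the signatures visible in this diff.

import torch
from peft import LoraConfig, get_peft_model
from torch import nn, optim

from concrete.ml.torch.lora import LoraTrainer


class TinyMLP(nn.Module):
    """Small two-class MLP whose linear layers receive LoRA adapters (illustrative)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


# Attach LoRA adapters to the two linear layers
peft_model = get_peft_model(TinyMLP(), LoraConfig(target_modules=["fc1", "fc2"], r=4))

# One object now gathers the optimizer, the loss function and the training arguments
optimizer = optim.Adam(filter(lambda p: p.requires_grad, peft_model.parameters()), lr=0.01)
lora_trainer = LoraTrainer(
    peft_model,
    optimizer=optimizer,
    loss_fn=nn.CrossEntropyLoss(),
    training_args={"gradient_accumulation_steps": 1},
)

# compile() takes a representative (inputs, labels) batch and a quantization bit-width,
# replacing the explicit toggle_calibrate / compile_model sequence of the old API
calibration_batch = (torch.randn(16, 2), torch.randint(0, 2, (16,)))
lora_trainer.compile(calibration_batch, n_bits=8)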
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fine-tuning on Task 2 with LoRA:\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 0%| | 0/10 [00:00<?, ?epoch/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 0%| | 0/10 [00:34<?, ?epoch/s, Epoch=1, Avg Loss=2.3775, Time=34.38s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 10%|█ | 1/10 [00:34<05:09, 34.38s/epoch, Epoch=1, Avg Loss=2.3775, Time=34.38s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 10%|█ | 1/10 [01:07<05:09, 34.38s/epoch, Epoch=2, Avg Loss=1.6292, Time=32.99s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 20%|██ | 2/10 [01:07<04:28, 33.56s/epoch, Epoch=2, Avg Loss=1.6292, Time=32.99s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 20%|██ | 2/10 [01:39<04:28, 33.56s/epoch, Epoch=3, Avg Loss=0.8214, Time=31.86s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 30%|███ | 3/10 [01:39<03:49, 32.79s/epoch, Epoch=3, Avg Loss=0.8214, Time=31.86s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 30%|███ | 3/10 [02:10<03:49, 32.79s/epoch, Epoch=4, Avg Loss=0.5415, Time=31.45s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 40%|████ | 4/10 [02:10<03:13, 32.26s/epoch, Epoch=4, Avg Loss=0.5415, Time=31.45s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 40%|████ | 4/10 [02:42<03:13, 32.26s/epoch, Epoch=5, Avg Loss=0.3884, Time=31.78s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 50%|█████ | 5/10 [02:42<02:40, 32.09s/epoch, Epoch=5, Avg Loss=0.3884, Time=31.78s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 50%|█████ | 5/10 [03:14<02:40, 32.09s/epoch, Epoch=6, Avg Loss=0.3246, Time=32.02s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 60%|██████ | 6/10 [03:14<02:08, 32.07s/epoch, Epoch=6, Avg Loss=0.3246, Time=32.02s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 60%|██████ | 6/10 [03:45<02:08, 32.07s/epoch, Epoch=7, Avg Loss=0.3145, Time=31.47s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 70%|███████ | 7/10 [03:45<01:35, 31.87s/epoch, Epoch=7, Avg Loss=0.3145, Time=31.47s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 70%|███████ | 7/10 [04:17<01:35, 31.87s/epoch, Epoch=8, Avg Loss=0.2942, Time=31.38s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 80%|████████ | 8/10 [04:17<01:03, 31.72s/epoch, Epoch=8, Avg Loss=0.2942, Time=31.38s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 80%|████████ | 8/10 [04:49<01:03, 31.72s/epoch, Epoch=9, Avg Loss=0.2913, Time=31.65s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 90%|█████████ | 9/10 [04:49<00:31, 31.70s/epoch, Epoch=9, Avg Loss=0.2913, Time=31.65s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 90%|█████████ | 9/10 [05:20<00:31, 31.70s/epoch, Epoch=10, Avg Loss=0.2978, Time=31.63s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 100%|██████████| 10/10 [05:20<00:00, 31.68s/epoch, Epoch=10, Avg Loss=0.2978, Time=31.63s, FHE Mode=execute]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
"Training: 100%|██████████| 10/10 [05:20<00:00, 32.06s/epoch, Epoch=10, Avg Loss=0.2978, Time=31.63s, FHE Mode=execute]"
"Training: 100%|██████████| 10/10 [04:42<00:00, 28.28s/epoch, Epoch=10, Avg Loss=0.2978, FHE Mode=execute]"
]
},
{
@@ -513,55 +327,7 @@
],
"source": [
"# Fine-tune the model on Task 2 using LoRA\n",
"\n",
"\n",
"def train_hybrid_model(hybrid_model, train_loader, num_epochs=50, fhe=\"disable\"):\n",
" \"\"\"Train the model using the hybrid FHE model with gradient accumulation.\n",
"\n",
" Args:\n",
" hybrid_model (HybridFHEModel): The compiled hybrid model.\n",
" train_loader (DataLoader): DataLoader for training data.\n",
" num_epochs (int): Number of epochs to train.\n",
" fhe (str): FHE mode ('disable', 'simulate', or 'execute').\n",
" \"\"\"\n",
" device = torch.device(\"cpu\")\n",
" lora_training.to(device)\n",
" peft_model.train()\n",
" lora_training.toggle_run_optimizer(enable=True)\n",
"\n",
" # Create the main epoch progress bar\n",
" epoch_pbar = tqdm(range(1, num_epochs + 1), desc=\"Training\", unit=\"epoch\")\n",
"\n",
" for epoch in epoch_pbar:\n",
" total_loss = 0\n",
" start_time = time.time()\n",
"\n",
" for x_batch, y_batch in train_loader:\n",
" x_batch = x_batch.to(device)\n",
" y_batch = y_batch.to(device)\n",
"\n",
" loss, _ = hybrid_model((x_batch, y_batch), fhe=fhe)\n",
" total_loss += loss.item()\n",
"\n",
" # Calculate average loss and epoch time\n",
" avg_loss = total_loss / len(train_loader)\n",
" epoch_time = time.time() - start_time\n",
"\n",
" # Update epoch progress bar\n",
" epoch_pbar.set_postfix(\n",
" {\n",
" \"Epoch\": epoch,\n",
" \"Avg Loss\": f\"{avg_loss:.4f}\",\n",
" \"Time\": f\"{epoch_time:.2f}s\",\n",
" \"FHE Mode\": fhe,\n",
" }\n",
" )\n",
"\n",
" print(f\"Training completed. Final Avg Loss: {avg_loss:.4f}, FHE Mode: {fhe}\")\n",
"\n",
"\n",
"print(\"Fine-tuning on Task 2 with LoRA:\")\n",
"train_hybrid_model(hybrid_model, train_loader_task2, num_epochs=10, fhe=\"execute\")"
"lora_trainer.train(train_loader_task2, num_epochs=10, fhe=\"execute\")"
]
},
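With the hand-written train_hybrid_model loop removed, the tqdm progress bar, device placement and optimizer toggling now live inside LoraTrainer.train. The removed docstring lists the accepted fhe modes ('disable', 'simulate', 'execute'); the short sketch below, continuing from the notebook's lora_trainer and train_loader_task2, only restates that call pattern, and the per-mode comments are assumptions drawn from that docstring rather than verified behaviour.

# Same entry point under the three FHE modes named in the removed docstring
lora_trainer.train(train_loader_task2, num_epochs=10, fhe="disable")   # plain PyTorch, no encryption
lora_trainer.train(train_loader_task2, num_epochs=10, fhe="simulate")  # emulate the quantized FHE computation in the clear
lora_trainer.train(train_loader_task2, num_epochs=10, fhe="execute")   # run the outsourced layers under real FHE (slowest)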
{
@@ -674,7 +440,7 @@
"if path.is_dir() and any(path.iterdir()):\n",
" shutil.rmtree(path)\n",
"\n",
"hybrid_model.save_and_clear_private_info(path)\n",
"lora_trainer.save_and_clear_private_info(path)\n",
"\n",
"# At this point, the hybrid_model only contains the trainable parameters of the LoRA layers.\n",
"peft_model.print_trainable_parameters()"