experiments: make notebook more like the pipeline (#280)
* experiments: make notebook more like the pipeline

* update notebook comments
Dave Berenbaum authored Jan 8, 2024
1 parent 21712e9 commit 5be5f1d
Showing 2 changed files with 40 additions and 48 deletions.
example-get-started-experiments/code/notebooks/TrainSegModel.ipynb (84 changes: 38 additions, 46 deletions)
@@ -6,30 +6,11 @@
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"ROOT = Path(\"../\")\n",
"DATA = ROOT / \"data\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import shutil\n",
"from functools import partial\n",
"from pathlib import Path\n",
"import warnings\n",
"\n",
"import numpy as np\n",
"import torch\n",
@@ -41,7 +22,10 @@
"from fastai.vision.all import (Resize, SegmentationDataLoaders,\n",
" imagenet_stats, models, unet_learner)\n",
"from ruamel.yaml import YAML\n",
"from PIL import Image"
"from PIL import Image\n",
"\n",
"os.chdir(\"..\")\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
@@ -77,14 +61,14 @@
"source": [
"test_regions = [\"REGION_1-\"]\n",
"\n",
"img_fpaths = get_files(DATA / \"pool_data\" / \"images\", extensions=\".jpg\")\n",
"img_fpaths = get_files(Path(\"data\") / \"pool_data\" / \"images\", extensions=\".jpg\")\n",
"\n",
"train_data_dir = DATA / \"train_data\"\n",
"train_data_dir = Path(\"data\") / \"train_data\"\n",
"train_data_dir.mkdir(exist_ok=True)\n",
"test_data_dir = DATA / \"test_data\"\n",
"test_data_dir = Path(\"data\") / \"test_data\"\n",
"test_data_dir.mkdir(exist_ok=True)\n",
"for img_path in img_fpaths:\n",
" msk_path = DATA / \"pool_data\" / \"masks\" / f\"{img_path.stem}.png\"\n",
" msk_path = Path(\"data\") / \"pool_data\" / \"masks\" / f\"{img_path.stem}.png\"\n",
" if any(region in str(img_path) for region in test_regions):\n",
" shutil.copy(img_path, test_data_dir)\n",
" shutil.copy(msk_path, test_data_dir)\n",
@@ -182,44 +166,42 @@
" dice_list.append(dice)\n",
" return np.mean(dice_list)\n",
"\n",
"\n",
"def evaluate(learn):\n",
" test_img_fpaths = sorted(get_files(DATA / \"test_data\", extensions=\".jpg\"))\n",
" test_img_fpaths = sorted(get_files(Path(\"data\") / \"test_data\", extensions=\".jpg\"))\n",
" test_dl = learn.dls.test_dl(test_img_fpaths)\n",
" preds, _ = learn.get_preds(dl=test_dl)\n",
" masks_pred = np.array(preds[:, 1, :] > 0.5, dtype=np.uint8)\n",
" test_mask_fpaths = [\n",
" get_mask_path(fpath, DATA / \"test_data\") for fpath in test_img_fpaths\n",
" get_mask_path(fpath, Path(\"data\") / \"test_data\") for fpath in test_img_fpaths\n",
" ]\n",
" masks_true = [Image.open(mask_path) for mask_path in test_mask_fpaths]\n",
"\n",
" dice_multi = 0.0\n",
" for ii in range(len(masks_true)):\n",
" mask_pred, mask_true = masks_pred[ii], masks_true[ii]\n",
" width, height = mask_true.shape[1], mask_true.shape[0]\n",
" mask_pred = np.array(\n",
" Image.fromarray(mask_pred).resize((width, height)),\n",
" Image.fromarray(mask_pred).resize((mask_true.shape[1], mask_true.shape[0])),\n",
" dtype=int\n",
" )\n",
" mask_true = np.array(mask_true, dtype=int)\n",
" dice_multi += dice(mask_true, mask_pred) / len(masks_true)\n",
"\n",
" return dice_multi"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"train_arch = 'shufflenet_v2_x2_0'\n",
"models_dir = ROOT / \"models\"\n",
"models_dir.mkdir(exist_ok=True)\n",
"results_dir = ROOT / \"results\" / \"train\"\n",
"\n",
"for base_lr in [0.001, 0.005, 0.01]:\n",
" # initialize dvclive, optionally provide output path, and save results as a dvc experiment\n",
" with Live(str(results_dir), save_dvc_exp=True, report=\"notebook\") as live:\n",
" # initialize dvclive, optionally provide output path, and show report in notebook\n",
" # don't save dvc experiment until post-training metrics below\n",
" with Live(\"results/train\", report=\"notebook\", save_dvc_exp=False) as live:\n",
" # log a parameter\n",
" live.log_param(\"train_arch\", train_arch)\n",
" fine_tune_args = {\n",
@@ -237,19 +219,22 @@
" **fine_tune_args,\n",
" cbs=[DVCLiveCallback(live=live)])\n",
"\n",
" learn.export(fname=(models_dir / \"model.pkl\").absolute())\n",
"\n",
" # add additional post-training summary metrics\n",
" live.summary[\"evaluate/dice_multi\"] = evaluate(learn)\n",
"\n",
" # save model artifact to dvc\n",
" models_dir = Path(\"models\")\n",
" models_dir.mkdir(exist_ok=True)\n",
" learn.export(fname=(models_dir / \"model.pkl\").absolute())\n",
" torch.save(learn.model, (models_dir / \"model.pth\").absolute())\n",
" live.log_artifact(\n",
" str(models_dir / \"model.pkl\"),\n",
" type=\"model\",\n",
" name=\"pool-segmentation\",\n",
" desc=\"This is a Computer Vision (CV) model that's segmenting out swimming pools from satellite images.\",\n",
" labels=[\"cv\", \"segmentation\", \"satellite-images\", \"unet\"],\n",
" )\n"
" )\n",
"\n",
" # add additional post-training summary metrics.\n",
" with Live(\"results/evaluate\") as live:\n",
" live.summary[\"dice_multi\"] = evaluate(learn)"
]
},
{
@@ -280,6 +265,13 @@
"source": [
"learn.show_results(max_n=6, alpha=0.7)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -298,7 +290,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.11.6"
},
"vscode": {
"interpreter": {
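
To summarize the notebook change for anyone skimming the diff: each run now logs the way the pipeline does, with a training context that does not save a DVC experiment on its own and a separate evaluate context that records the post-training metric. Below is a minimal sketch of that dvclive pattern, distilled from the cells above and assuming dvclive 3.x; the dummy model file and the placeholder dice value are illustrative only, not output from this repo.

```python
# Minimal sketch of the per-run logging flow in the updated notebook.
# The real notebook also passes report="notebook" to Live and trains a
# fastai unet_learner with cbs=[DVCLiveCallback(live=live)].
from pathlib import Path

from dvclive import Live

models_dir = Path("models")
models_dir.mkdir(exist_ok=True)
(models_dir / "model.pkl").touch()  # stand-in for learn.export(...)

# Training context: log params and the model artifact, but don't save the
# DVC experiment yet (save_dvc_exp=False), matching the comment in the diff.
with Live("results/train", save_dvc_exp=False) as live:
    live.log_param("train_arch", "shufflenet_v2_x2_0")
    # ... fine_tune(...) with DVCLiveCallback(live=live) would run here ...
    live.log_artifact(
        str(models_dir / "model.pkl"),
        type="model",
        name="pool-segmentation",
    )

# Evaluate context: the summary is written to results/evaluate/metrics.json,
# the metric path that generate.sh now sorts experiments by.
with Live("results/evaluate") as live:
    live.summary["dice_multi"] = 0.85  # placeholder; the notebook uses evaluate(learn)
```
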
example-get-started-experiments/generate.sh (4 changes: 2 additions, 2 deletions)
@@ -82,7 +82,7 @@ pip install -r requirements.txt --extra-index-url https://download.pytorch.org/w
pip install jupyter
jupyter nbconvert --execute 'notebooks/TrainSegModel.ipynb' --inplace
# Apply best experiment
BEST_EXP_ROW=$(dvc exp show --drop '.*' --keep 'Experiment|evaluate/dice_multi|base_lr' --csv --sort-by evaluate/dice_multi | tail -n 1)
BEST_EXP_ROW=$(dvc exp show --drop '.*' --keep 'Experiment|results/evaluate/metrics.json:dice_multi|base_lr' --csv --sort-by 'results/evaluate/metrics.json:dice_multi' | tail -n 1)
BEST_EXP_NAME=$(echo $BEST_EXP_ROW | cut -d, -f 1)
BEST_EXP_BASE_LR=$(echo $BEST_EXP_ROW | cut -d, -f 3)
dvc exp apply $BEST_EXP_NAME
@@ -98,7 +98,7 @@ cp $HERE/code/params.yaml .
sed -e "s/base_lr: 0.01/base_lr: $BEST_EXP_BASE_LR/" -i".bkp" params.yaml
rm params.yaml.bkp

git rm -r --cached 'results'
git rm -r --cached 'results' 'models'
git commit -m "stop tracking results"

dvc stage add -n data_split \
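
For context on the metric path change above: dvclive writes each Live context's summary to metrics.json inside its output directory, so the notebook's post-training dice score lands in results/evaluate/metrics.json, which is why the dvc exp show columns now reference results/evaluate/metrics.json:dice_multi. A small sketch of reading that file directly (the key name comes from the notebook cells; the fallback message is illustrative):

```python
# Read the post-training metric that generate.sh sorts experiments by.
import json
from pathlib import Path

metrics_path = Path("results") / "evaluate" / "metrics.json"
if metrics_path.exists():
    dice_multi = json.loads(metrics_path.read_text())["dice_multi"]
    print(f"dice_multi: {dice_multi:.3f}")
else:
    print(f"{metrics_path} not found; run the notebook or the pipeline first")
```
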
