Skip to content

Commit

Permalink
modified: METABRIC.ipynb
Browse files Browse the repository at this point in the history
	modified:   dsm_utilites.py
  • Loading branch information
chiragnagpal committed Apr 2, 2020
1 parent c382bd2 commit 922f4ca
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 243 deletions.
316 changes: 81 additions & 235 deletions METABRIC.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,36 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 43,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"[43.68333435058594, 86.86666870117188, 146.33333587646484, 283.5426806640625]"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"\n",
"\n",
"dat1 = df[['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8']]\n",
"times = (df['duration'].values+1)\n",
"events = df['event'].values\n",
"data = dat1.to_numpy()\n",
"folds = np.array([1]*191 + [2]*191 + [3]*191 + [4]*191 + [5]*190 + [6]*190 + [7]*190 + [8]*190 + [9]*190 + [10]*190 )\n",
"np.random.seed(100)\n",
"folds = np.array([1]*381 + [2]*381 + [3]*381 + [4]*381 + [5]*380 )\n",
"np.random.seed(0)\n",
"np.random.shuffle(folds)\n",
"quantiles = np.quantile(times[events==1], [0.25, .5, .75, .99]).tolist()"
"np.quantile(times[events==1], [0.25, .5, .75, .99]).tolist()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -65,7 +75,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -74,7 +84,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 74,
"metadata": {},
"outputs": [
{
Expand All @@ -83,7 +93,7 @@
"<module 'dsm_utilites' from '/Users/chiragn/Research/ICML2020/DeepSurvivalMachines/dsm_utilites.py'>"
]
},
"execution_count": 32,
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -98,252 +108,88 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 0/10000 [00:00<?, ?it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 0/10000 [00:00<?, ?it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"val len: 228\n",
"tr len: 1485\n",
"Censoring in Fold: 0.5723905723905723\n",
"Censoring in Fold: 0.5723905723905723\n",
"tr len: 1295\n",
"Censoring in Fold: 0.5814671814671815\n",
"Censoring in Fold: 0.5814671814671815\n",
"Pretraining the Underlying Distributions...\n"
]
},
{
"name": "stderr",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ede4627a9dd342f0ba37184b228b739c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 1/10000 [00:00<36:58, 4.51it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 2/10000 [00:00<36:55, 4.51it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 3/10000 [00:00<37:07, 4.49it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 4/10000 [00:00<37:15, 4.47it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 5/10000 [00:01<37:15, 4.47it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 6/10000 [00:01<37:19, 4.46it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 7/10000 [00:01<37:52, 4.40it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 8/10000 [00:01<38:33, 4.32it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 9/10000 [00:02<38:19, 4.34it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 10/10000 [00:02<37:54, 4.39it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 11/10000 [00:02<37:52, 4.40it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 12/10000 [00:02<37:52, 4.40it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 13/10000 [00:02<37:44, 4.41it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 14/10000 [00:03<37:50, 4.40it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 15/10000 [00:03<38:02, 4.38it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 16/10000 [00:03<38:42, 4.30it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 17/10000 [00:03<38:20, 4.34it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 18/10000 [00:04<38:02, 4.37it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 19/10000 [00:04<38:00, 4.38it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 20/10000 [00:04<38:09, 4.36it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 21/10000 [00:04<38:10, 4.36it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 22/10000 [00:05<38:41, 4.30it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 23/10000 [00:05<38:25, 4.33it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 24/10000 [00:05<38:49, 4.28it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 25/10000 [00:05<38:32, 4.31it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 26/10000 [00:05<38:16, 4.34it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 27/10000 [00:06<38:14, 4.35it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 28/10000 [00:06<38:26, 4.32it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 29/10000 [00:06<38:26, 4.32it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 30/10000 [00:06<38:20, 4.33it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\n",
"\n",
"\n",
"\n",
"\n",
"\n",
" 0%| | 31/10000 [00:07<37:58, 4.38it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A"
"202.8469264171793 1.2680954189977467\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-33-b86f30c6fc56>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdouble\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 77\u001b[0;31m \u001b[0mdsm_utilites\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainDSM\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mquantiles\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mx_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0me_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_valid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt_valid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0me_valid\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0malpha\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malpha\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 78\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Research/ICML2020/DeepSurvivalMachines/dsm_utilites.py\u001b[0m in \u001b[0;36mtrainDSM\u001b[0;34m(model, quantiles, x_train, t_train, e_train, x_valid, t_valid, e_valid, n_iter, lr, ELBO, lambd, alpha, thres, bs)\u001b[0m\n\u001b[1;32m 159\u001b[0m ELBO=ELBO, lambd=lambd, alpha=alpha)\n\u001b[1;32m 160\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 161\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 162\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda2/envs/py375/lib/python3.7/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \"\"\"\n\u001b[0;32m--> 195\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 196\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/anaconda2/envs/py375/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 97\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 98\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m allow_unreachable=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "39716c9dc58548e5afb3f314a08febe3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.5295747584993196 (0.7436527098215582, 0.6811342904030737, 0.5970621901813846, 0.611855947999837)\n",
"3.515131386970598 (0.7386963678694428, 0.6716928435166729, 0.5849140201392863, 0.62534764825785)\n",
"3.5158226769628036 (0.7318455746593371, 0.6699128823791801, 0.5911467483533817, 0.6331405037220301)\n",
"3.515479804334142 (0.7297725821056623, 0.672073403791732, 0.5936461114676398, 0.6369502882527434)\n",
"3.5152090747964597 (0.7299993252087804, 0.6723811713623161, 0.5948084129361593, 0.6397806508689314)\n",
"3.515229667328146 (0.727796010086639, 0.6730755262613353, 0.5972531310617756, 0.6420898707648605)\n",
"3.515459661315446 (0.7282370222892353, 0.6703104281959925, 0.5979750789683245, 0.6429366444808753)\n",
"3.5152011552733837 (0.7284817501550502, 0.6689068376100378, 0.5977576124523662, 0.6441363144872003)\n",
"3.5156224070684705 (0.7298338654512917, 0.6682243802957398, 0.5976506274847705, 0.645576700159252)\n",
"3.516007568690427 (0.7302225119493331, 0.6672430764040999, 0.5983333250469652, 0.6398588975503352)\n",
"3.516402141571075 (0.7294949009807973, 0.6672938918473681, 0.5990394217700785, 0.6353057345001489)\n"
]
}
],
"source": [
"#set parameter grid\n",
"params = [{'G':4, 'mlptyp':1,'HIDDEN':[], 'n_iter':int(500), 'lr':1e-3, 'ELBO':True, 'mean':False, \\\n",
" 'lambd':0, 'alpha':1,'thres':1e-3, 'bs':int(25)}] \n",
"params = [{'G':6, 'mlptyp':2,'HIDDEN':[100], 'n_iter':int(1000), 'lr':1e-3, 'ELBO':True, 'mean':False, \\\n",
" 'lambd':0, 'alpha':1,'thres':1e-3, 'bs':int(25)}]\n",
"\n",
"\n",
"#set val data size\n",
"vsize = int(0.15*1523)\n",
Expand Down Expand Up @@ -417,7 +263,7 @@
" model = dsm.DeepSurvivalMachines(D, K, mlptyp, HIDDEN, dist='Weibull')\n",
" model.double()\n",
" \n",
" dsm_utilites.trainDSM(model,quantiles,x_train, t_train, e_train, x_valid, t_valid, e_valid,lr=lr,bs=bs,alpha=alpha )\n",
" model, i = dsm_utilites.trainDSM(model,quantiles,x_train, t_train, e_train, x_valid, t_valid, e_valid,lr=lr,bs=bs,alpha=alpha )\n",
" \n",
" \n",
" print (\"TEST PERFORMANCE\")\n",
Expand Down
Loading

0 comments on commit 922f4ca

Please sign in to comment.