From 3a9d013f136cf46040d21730e99a080f8597fc0a Mon Sep 17 00:00:00 2001 From: Chirag Nagpal Date: Thu, 2 Apr 2020 14:06:08 -0400 Subject: [PATCH] modified: METABRIC.ipynb modified: dsm_loss.py modified: dsm_utilites.py --- METABRIC.ipynb | 315 ++++++++++++++++++++++++++++++++++++++++-------- dsm_loss.py | 2 +- dsm_utilites.py | 8 +- 3 files changed, 269 insertions(+), 56 deletions(-) diff --git a/METABRIC.ipynb b/METABRIC.ipynb index e1a2693..a45a962 100644 --- a/METABRIC.ipynb +++ b/METABRIC.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 25, + "execution_count": 95, "metadata": {}, "outputs": [], "source": [ @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 96, "metadata": {}, "outputs": [], "source": [ @@ -35,23 +35,13 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 123, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[43.68333435058594, 86.86666870117188, 146.33333587646484, 283.5426806640625]" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "import numpy as np\n", "\n", + "\n", "dat1 = df[['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8']]\n", "times = (df['duration'].values+1)\n", "events = df['event'].values\n", @@ -59,12 +49,12 @@ "folds = np.array([1]*381 + [2]*381 + [3]*381 + [4]*381 + [5]*380 )\n", "np.random.seed(0)\n", "np.random.shuffle(folds)\n", - "np.quantile(times[events==1], [0.25, .5, .75, .99]).tolist()" + "quantiles = np.quantile(times[events==1], [0.25, .5, .75, .99]).tolist()" ] }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 124, "metadata": {}, "outputs": [], "source": [ @@ -75,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 125, "metadata": {}, "outputs": [], "source": [ @@ -84,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 126, "metadata": {}, "outputs": [ { @@ -93,7 +83,7 @@ "" ] }, - "execution_count": 74, + "execution_count": 126, "metadata": {}, "output_type": "execute_result" } @@ -108,38 +98,71 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 127, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "data": { + "text/plain": [ + "1713.6" + ] + }, + "execution_count": 127, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "float(len(dat1)*9)/10" + ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, + "execution_count": 128, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "val len: 228\n", - "tr len: 1295\n", - "Censoring in Fold: 0.5814671814671815\n", - "Censoring in Fold: 0.5814671814671815\n", + "val len: 256\n", + "tr len: 1267\n", + "Censoring in Fold: 0.585635359116022\n", + "Censoring in Fold: 0.585635359116022\n", + "Weibull\n", "Pretraining the Underlying Distributions...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ede4627a9dd342f0ba37184b228b739c", + "model_id": "94a08c4f8cb9413cb5550461cef08843", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "200.75226152672906 1.2680803184752372\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5f9ed6a9d53147a8976a505baa8d8827", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))" + "HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))" ] }, "metadata": {}, @@ -150,13 +173,139 @@ "output_type": "stream", "text": [ "\n", - "202.8469264171793 1.2680954189977467\n" + "TEST PERFORMANCE\n", + "(0.7154317427767943, 0.6613066844703644, 0.6095326277314909, 0.5646184019234619)\n", + "val len: 256\n", + "tr len: 1267\n", + "Censoring in Fold: 0.5864246250986582\n", + "Censoring in Fold: 0.5864246250986582\n", + "Weibull\n", + "Pretraining the Underlying Distributions...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "39e682733b7c49a193c5d9db1fc25192", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "199.7153561845356 1.2632128253408417\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fd9d3c6b464944b39af18518b6750d37", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TEST PERFORMANCE\n", + "(0.6007865597948178, 0.5885119373707697, 0.586134286427175, 0.601290434891957)\n", + "val len: 256\n", + "tr len: 1267\n", + "Censoring in Fold: 0.579321231254933\n", + "Censoring in Fold: 0.579321231254933\n", + "Weibull\n", + "Pretraining the Underlying Distributions...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "70d52a9cada6477597b7c145f7903f92", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "203.71121598683038 1.2771277338630311\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "73cbf5108ace4da587dc1f0dde8d5040", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TEST PERFORMANCE\n", + "(0.7121724313147351, 0.6818384881157109, 0.6577998627814042, 0.6497338850905421)\n", + "val len: 256\n", + "tr len: 1267\n", + "Censoring in Fold: 0.5785319652722968\n", + "Censoring in Fold: 0.5785319652722968\n", + "Weibull\n", + "Pretraining the Underlying Distributions...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c293ce1812334109a469224d956c51c6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "200.38728541667388 1.2916873727801628\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "39716c9dc58548e5afb3f314a08febe3", + "model_id": "5ef0390a60364d89a9d697f5b4f11147", "version_major": 2, "version_minor": 0 }, @@ -171,28 +320,69 @@ "name": "stdout", "output_type": "stream", "text": [ - "3.5295747584993196 (0.7436527098215582, 0.6811342904030737, 0.5970621901813846, 0.611855947999837)\n", - "3.515131386970598 (0.7386963678694428, 0.6716928435166729, 0.5849140201392863, 0.62534764825785)\n", - "3.5158226769628036 (0.7318455746593371, 0.6699128823791801, 0.5911467483533817, 0.6331405037220301)\n", - "3.515479804334142 (0.7297725821056623, 0.672073403791732, 0.5936461114676398, 0.6369502882527434)\n", - "3.5152090747964597 (0.7299993252087804, 0.6723811713623161, 0.5948084129361593, 0.6397806508689314)\n", - "3.515229667328146 (0.727796010086639, 0.6730755262613353, 0.5972531310617756, 0.6420898707648605)\n", - "3.515459661315446 (0.7282370222892353, 0.6703104281959925, 0.5979750789683245, 0.6429366444808753)\n", - "3.5152011552733837 (0.7284817501550502, 0.6689068376100378, 0.5977576124523662, 0.6441363144872003)\n", - "3.5156224070684705 (0.7298338654512917, 0.6682243802957398, 0.5976506274847705, 0.645576700159252)\n", - "3.516007568690427 (0.7302225119493331, 0.6672430764040999, 0.5983333250469652, 0.6398588975503352)\n", - "3.516402141571075 (0.7294949009807973, 0.6672938918473681, 0.5990394217700785, 0.6353057345001489)\n" + "TEST PERFORMANCE\n", + "(0.7604568637757466, 0.7232963903416784, 0.6617261376642131, 0.646405712483512)\n", + "val len: 256\n", + "tr len: 1268\n", + "Censoring in Fold: 0.582807570977918\n", + "Censoring in Fold: 0.582807570977918\n", + "Weibull\n", + "Pretraining the Underlying Distributions...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9532b513edea4391ba35b760fbcfee4b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "200.8431094911515 1.2591288578083222\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7194498ec1da4372800fc828c409ab72", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TEST PERFORMANCE\n", + "(0.7380987174602022, 0.6644269287843599, 0.6575335218015144, 0.6570588366776975)\n" ] } ], "source": [ "#set parameter grid\n", + "\n", "params = [{'G':6, 'mlptyp':2,'HIDDEN':[100], 'n_iter':int(1000), 'lr':1e-3, 'ELBO':True, 'mean':False, \\\n", - " 'lambd':0, 'alpha':1,'thres':1e-3, 'bs':int(25)}]\n", + " 'lambd':0, 'alpha':1,'thres':1e-3, 'bs':int(25), 'dist': 'Weibull'}]\n", "\n", "\n", "#set val data size\n", - "vsize = int(0.15*1523)\n", + "vsize = int(0.15*1712)\n", "\n", "torch.manual_seed(0)\n", "\n", @@ -254,13 +444,16 @@ " t_test = torch.from_numpy(t_test).double() \n", "\n", "\n", - " K, mlptyp, HIDDEN, n_iter, lr, ELBO, mean, lambd, alpha, thres, bs = \\\n", + " K, mlptyp, HIDDEN, n_iter, lr, ELBO, mean, lambd, alpha, thres, bs,dist = \\\n", " param['G'], param['mlptyp'], param['HIDDEN'], param['n_iter'], param['lr'], \\\n", - " param['ELBO'], param['mean'], param['lambd'], param['alpha'], param['thres'], param['bs']\n", + " param['ELBO'], param['mean'], param['lambd'], param['alpha'], param['thres'],\\\n", + " param['bs'], param['dist'] \n", "\n", " D = x_train.shape[1]\n", " \n", - " model = dsm.DeepSurvivalMachines(D, K, mlptyp, HIDDEN, dist='Weibull')\n", + " print (dist)\n", + " \n", + " model = dsm.DeepSurvivalMachines(D, K, mlptyp, HIDDEN, dist=dist)\n", " model.double()\n", " \n", " model, i = dsm_utilites.trainDSM(model,quantiles,x_train, t_train, e_train, x_valid, t_valid, e_valid,lr=lr,bs=bs,alpha=alpha )\n", @@ -295,6 +488,26 @@ "quantiles" ] }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'LogNormal'" + ] + }, + "execution_count": 93, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.dist" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/dsm_loss.py b/dsm_loss.py index 14aa331..60f192c 100644 --- a/dsm_loss.py +++ b/dsm_loss.py @@ -3,7 +3,7 @@ def _logNormalLoss(model, x, t, e): import numpy as np - shape, scale, logits = model.forward(x) + shape, scale, logits = model.forward(x, adj=False) k_ = shape.expand(x.shape[0], -1) b_ = scale.expand(x.shape[0], -1) diff --git a/dsm_utilites.py b/dsm_utilites.py index e1b8970..96420c9 100644 --- a/dsm_utilites.py +++ b/dsm_utilites.py @@ -54,7 +54,7 @@ def increaseCensoring(e, t, p): def pretrainDSM(model, x_train, t_train, e_train, x_valid, t_valid, e_valid, \ - n_iter=1000, lr=1e-2, thres=1e-4): + n_iter=10000, lr=1e-2, thres=1e-4): from tqdm import tqdm_notebook as tqdm @@ -125,12 +125,12 @@ def trainDSM(model,quantiles , x_train, t_train, e_train, x_valid, t_valid, e_va G = model.k mlptyp = model.mlptype HIDDEN = model.HIDDEN - + dist = model.dist print ("Pretraining the Underlying Distributions...") premodel = pretrainDSM(model, x_train, t_train, e_train, x_valid, t_valid, e_valid, \ - n_iter=1000, lr=1e-2, thres=1e-4) + n_iter=10000, lr=1e-2, thres=1e-4) print(torch.exp(-premodel.scale).cpu().data.numpy()[0], \ @@ -139,7 +139,7 @@ def trainDSM(model,quantiles , x_train, t_train, e_train, x_valid, t_valid, e_va model = DeepSurvivalMachines(x_train.shape[1], G, mlptyp=mlptyp, HIDDEN=HIDDEN, \ - init=(float(premodel.shape[0]), float(premodel.scale[0]) )) + init=(float(premodel.shape[0]), float(premodel.scale[0]) ), dist=dist ) model.double()