diff --git a/dsm/__init__.py b/dsm/__init__.py
index 07faa6b..9c972fc 100644
--- a/dsm/__init__.py
+++ b/dsm/__init__.py
@@ -181,4 +181,4 @@
 """
-from dsm.dsm_api import DeepSurvivalMachines, DeepRecurrentSurvivalMachines
+from dsm.dsm_api import DeepSurvivalMachines, DeepRecurrentSurvivalMachines, DeepConvolutionalSurvivalMachines
diff --git a/dsm/datasets.py b/dsm/datasets.py
index 1e53e1e..60f531c 100644
--- a/dsm/datasets.py
+++ b/dsm/datasets.py
@@ -36,6 +36,8 @@
 from sklearn.impute import SimpleImputer
 from sklearn.preprocessing import StandardScaler
 
+import torchvision
+
 def increase_censoring(e, t, p):
 
   uncens = np.where(e == 1)[0]
@@ -192,6 +194,35 @@ def _load_support_dataset():
   remove = ~np.isnan(t)
   return x[remove], t[remove], e[remove]
 
+def _load_mnist():
+  """Helper function to load and preprocess the MNIST dataset.
+
+  The MNIST database of handwritten digits has a training set of 60,000
+  examples and a test set of 10,000 examples. It is a good database for
+  people who want to try learning techniques and pattern recognition
+  methods on real-world data while spending minimal effort on
+  preprocessing and formatting [1].
+
+  Please refer to http://yann.lecun.com/exdb/mnist/
+  for the original data source.
+
+  The digit label (plus one) is used as a synthetic event time, and half
+  of the instances are randomly right-censored via `increase_censoring`.
+
+  References
+  ----------
+  [1]: LeCun, Y. (1998). The MNIST database of handwritten digits.
+  http://yann.lecun.com/exdb/mnist/.
+
+  """
+
+  train = torchvision.datasets.MNIST(root='datasets/',
+                                     train=True, download=True)
+  x = train.data.numpy()
+  x = np.expand_dims(x, 1).astype(float)
+  t = train.targets.numpy().astype(float) + 1
+
+  e, t = increase_censoring(np.ones(t.shape), t, p=.5)
+
+  return x, t, e
 
 def load_dataset(dataset='SUPPORT', **kwargs):
   """Helper function to load datasets to test Survival Analysis models.
@@ -249,5 +280,7 @@ def load_dataset(dataset='SUPPORT', **kwargs):
     return _load_pbc_dataset(sequential)
   if dataset == 'FRAMINGHAM':
     return _load_framingham_dataset(sequential)
+  if dataset == 'MNIST':
+    return _load_mnist()
   else:
    raise NotImplementedError('Dataset '+dataset+' not implemented.')
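A quick smoke test for the new loader (a minimal sketch, not part of the patch; it assumes the package root is on the path and that torchvision can download MNIST into `datasets/`):

    from dsm import datasets
    import numpy as np

    # x: (60000, 1, 28, 28) float images; t: digit label + 1 as a synthetic
    # event time; e: event indicator, ~50% censored by increase_censoring(p=.5).
    x, t, e = datasets.load_dataset('MNIST')
    print(x.shape, t.shape, e.shape)  # (60000, 1, 28, 28) (60000,) (60000,)
    print(e.mean())                   # roughly 0.5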
diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py
index 4ccd340..b54bd8a 100644
--- a/dsm/dsm_api.py
+++ b/dsm/dsm_api.py
@@ -28,7 +28,7 @@
 """
 from dsm.dsm_torch import DeepSurvivalMachinesTorch
-from dsm.dsm_torch import DeepRecurrentSurvivalMachinesTorch
+from dsm.dsm_torch import DeepRecurrentSurvivalMachinesTorch, DeepConvolutionalSurvivalMachinesTorch
 from dsm.losses import predict_cdf
 import dsm.losses as losses
 from dsm.utilities import train_dsm, _get_padded_features, _get_padded_targets
@@ -66,7 +66,8 @@ def _gen_torch_model(self, inputdim, optimizer, risks):
 
   def fit(self, x, t, e, vsize=0.15,
           iters=1, learning_rate=1e-3, batch_size=100,
-          elbo=True, optimizer="Adam", random_state=100):
+          elbo=True, optimizer="Adam", random_state=100,
+          cuda=False):
 
     r"""This method is used to train an instance of the DSM model.
@@ -185,7 +186,7 @@ def predict_risk(self, x, t, risk=1):
                        "before calling `predict_risk`.")
 
-  def predict_survival(self, x, t, risk=1):
+  def predict_survival(self, x, t, risk=1, cuda=False):
     r"""Returns the estimated survival probability at time \( t \),
       \( \widehat{\mathbb{P}}(T > t|X) \) for some input data \( x \).
@@ -327,6 +328,31 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state):
                 x_val, t_val, e_val)
 
-class DeepConvolutionalSurvivalMachines(DeepRecurrentSurvivalMachines):
-  __doc__ = "..warning:: Not Implemented"
-  pass
+class DeepConvolutionalSurvivalMachines(DSMBase):
+  """The Deep Convolutional Survival Machines model to handle data with
+  image-based covariates.
+
+  """
+
+  def __init__(self, k=3, layers=None, hidden=None,
+               distribution='Weibull', temp=1000., discount=1.0,
+               typ='ConvNet'):
+    super(DeepConvolutionalSurvivalMachines, self).__init__(k=k,
+                                                            layers=layers,
+                                                            distribution=distribution,
+                                                            temp=temp,
+                                                            discount=discount)
+    self.hidden = hidden
+    self.typ = typ
+
+  def _gen_torch_model(self, inputdim, optimizer, risks):
+    """Helper function to return a torch model."""
+    return DeepConvolutionalSurvivalMachinesTorch(inputdim,
+                                                  k=self.k,
+                                                  layers=self.layers,
+                                                  hidden=self.hidden,
+                                                  dist=self.dist,
+                                                  temp=self.temp,
+                                                  discount=self.discount,
+                                                  optimizer=optimizer,
+                                                  typ=self.typ,
+                                                  risks=risks)
+
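A minimal usage sketch of the new API class (not part of the patch; it mirrors examples/conv_example.py below and assumes MNIST-style (N, 1, 28, 28) covariates, with `hidden` setting the width of the convnet embedding):

    from dsm import datasets, DeepConvolutionalSurvivalMachines
    import numpy as np

    x, t, e = datasets.load_dataset('MNIST')
    model = DeepConvolutionalSurvivalMachines(k=3, distribution='Weibull',
                                              hidden=64)
    model.fit(x, t, e, iters=10, learning_rate=1e-3, batch_size=100)

    # Evaluate at the quartiles of the uncensored event times.
    times = np.quantile(t[e == 1], [0.25, 0.5, 0.75]).tolist()
    out_survival = model.predict_survival(x, times)  # P(T > t | image)
    out_risk = model.predict_risk(x, times)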
+ + """ + + def __init__(self, inputdim, k, typ='ResNet', layers=1, + hidden=None, dist='Weibull', + temp=1000., discount=1.0, optimizer='Adam', risks=1): + super(DeepConvolutionalSurvivalMachinesTorch, self).__init__() + + self.k = k + self.dist = dist + self.temp = float(temp) + self.discount = float(discount) + self.optimizer = optimizer + self.hidden = hidden + self.layers = layers + self.typ = typ + self.risks = risks + + if self.dist in ['Weibull']: + self.act = nn.SELU() + self.shape = nn.ParameterDict({str(r+1): nn.Parameter(-torch.ones(k)) + for r in range(self.risks)}) + self.scale = nn.ParameterDict({str(r+1):nn.Parameter(-torch.ones(k)) + for r in range(self.risks)}) + elif self.dist in ['Normal']: + self.act = nn.Identity() + self.shape = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k)) + for r in range(self.risks)}) + self.scale = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k)) + for r in range(self.risks)}) + elif self.dist in ['LogNormal']: + self.act = nn.Tanh() + self.shape = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k)) + for r in range(self.risks)}) + self.scale = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k)) + for r in range(self.risks)}) + else: + raise NotImplementedError('Distribution: '+self.dist+' not implemented'+ + ' yet.') + + self.gate = nn.ModuleDict({str(r+1): nn.Sequential( + nn.Linear(hidden, k, bias=False) + ) for r in range(self.risks)}) + + self.scaleg = nn.ModuleDict({str(r+1): nn.Sequential( + nn.Linear(hidden, k, bias=True) + ) for r in range(self.risks)}) + + self.shapeg = nn.ModuleDict({str(r+1): nn.Sequential( + nn.Linear(hidden, k, bias=True) + ) for r in range(self.risks)}) + + if self.typ == 'ConvNet': +# self.cnn = torchvision.models.resnet18(pretrained=True).float() +# self.cnn.conv1 = torch.nn.Conv1d(1, 64, (7, 7), (2, 2), (3, 3), bias=False) +# self.linear = torch.nn.Linear(1000, hidden) + self.conv1 = nn.Conv2d(1, 6, 3) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 3) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, hidden) + + + def forward(self, x, risk='1'): + """The forward function that is called when data is passed through DSM. + + Args: + x: + a torch.tensor of the input features. 
+ + """ +# xrep = self.linear(self.cnn(x)) + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + xrep = self.fc3(x) + + dim = x.shape[0] + return(self.act(self.shapeg[risk](xrep))+self.shape[risk].expand(dim, -1), + self.act(self.scaleg[risk](xrep))+self.scale[risk].expand(dim, -1), + self.gate[risk](xrep)/self.temp) + + def get_shape_scale(self, risk='1'): + return(self.shape[risk], + self.scale[risk]) diff --git a/examples/conv_example.ipynb b/examples/conv_example.ipynb new file mode 100644 index 0000000..b13b330 --- /dev/null +++ b/examples/conv_example.ipynb @@ -0,0 +1,621 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import importlib\n", + "import os, sys\n", + "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", + "\n", + "from dsm import datasets, DeepSurvivalMachines, DeepConvolutionalSurvivalMachines\n", + "import numpy as np\n", + "from sksurv.metrics import concordance_index_ipcw, brier_score" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(60000, 1, 28, 28) (60000,) (60000,)\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAL6klEQVR4nO3dX4ild33H8fenu0pNBP+wQ6i7obMXoiwBGxlKaoqI64UlwfTKRkixqbI3tkaxSPQm9M4LEXNRhGXVWgzaEgOVVqyglbZQQmcTITWrIDF/Nm6aI/Ufgo0h317MkexMdjIz5zyZ5/nOvF83e+aZZ87z4WHnw2+e85zzTVUhSernt8YOIElajAUuSU1Z4JLUlAUuSU1Z4JLU1NH9PNixY8dqdXV1Pw8pSe2dP3/+x1W1snX7vhb46uoq6+vr+3lISWovyWNX2u4lFElqygKXpKYscElqygKXpKYscElqase7UJJ8DrgZeLqqrptvey3w98Aq8Cjw7qr6yUsXU5L6Wb3zn1+w7dFP3DTY8+9mBf63wDu3bLsT+GZVvR745vxrSdLclcr7xbYvYscCr6p/A/53y+ZbgC/MH38B+OPBEkmSdmXRa+DXVNWl+eOngGu22zHJmSTrSdZns9mCh5MkbbX0i5i1MRFi26kQVXW2qtaqam1l5QXvBJUkLWjRAv+fJL8DMP/36eEiSZJ2Y9EC/yrw3vnj9wL/OEwcSToYtrvbZMi7UHZzG+GXgLcBx5JcBO4CPgH8Q5L3AY8B7x4skSQdEEOW9ZXsWOBV9Z5tvnV64CySpD3wnZiS1JQFLklNWeCS1JQFLklNWeCS1JQFLklNWeCS1JQFLklNWeCS1JQFLklNWeCS1JQFLklNWeCS1NSOn0YoSR291BPhp5DBFbikA2c/JsJPIYMFLklNWeCS1JQFLklNWeCS1JQFLunA2Y+J8FPIkKoa7Ml2sra2Vuvr6/t2PEk6CJKcr6q1rdtdgUtSUxa4JDVlgUtSUxa4JDVlgUtSUxa4JDVlgUtSUxa4JDVlgUtSUxa4JDVlgUtSUxa4JDVlgUtSUxa4JDW11FT6JB8G3g8U8BBwe1X9aohgknqawjT4qeSY7FT6JMeBDwJrVXUdcAS4dahgkvqZwjT4qeToMJX+KPCKJEeBq4AfLR9JkrQbCxd4VT0JfBJ4HLgE/KyqvrF1vyRnkqwnWZ/NZosnlSRtsswllNcAtwAngdcBVye5bet+VXW2qtaqam1lZWXxpJKkTZa5hPIO4IdVNauqXwP3AW8ZJpYkaSfLFPjjwA1JrkoS4DRwYZhYkjqawjT4qeSY/FT6JH8N/AnwLPAg8P6q+r/t9ncqvSTt3XZT6Ze6D7yq7gLuWuY5JEmL8Z2YktSUBS5JTVngktSUBS5JTVngktSUBS5JTVngktSUBS5JTVngktSUBS5JTVngktSUBS5JTVngktTUUp9GKGk6pjCF3Rz7m8EVuHQATGEKuzn2P4MFLklNWeCS1JQFLklNWeCS1JQFLh0AU5jCbo79z7DUVPq9ciq9JO3ddlPpXYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1tdRU+iSvBs4B1wEF/HlV/ecAuaQ2pjD93BzTzDH1qfR3A1+vqjcCbwIuLB9J6mMK08/NMc0c+5Fh4RV4klcBbwX+DKCqngGeGSaWJGkny6zATwIz4PNJHkxyLsnVW3dKcibJepL12Wy2xOEkSZdbpsCPAm8GPlNV1wO/BO7culNVna2qtapaW1lZWeJwkqTLLVPgF4GLVXX//Ot72Sh0SdI+WLjAq+op4Ikkb5hvOg08PEgqqYkpTD83xzRzTH4qfZLfY+M2wpcDjwC3V9VPttvfqfSStHfbTaVf6j7wqvoO8IInlSS99HwnpiQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1tdSnEUpjmsLUcXOYY8wMrsDV0hSmjpvDHGNnsMAlqSkLXJKassAlqSkLXJKassDV0hSmjpvDHGNnWGoq/V45lV6S9m67qfSuwCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckp
qywCWpKQtckppaeip9kiPAOvBkVd28fCRN3RSmfZvDHB1ydJhKfwdwYYDnUQNTmPZtDnN0yDH5qfRJTgA3AeeGiSNJ2q1lV+CfBj4KPLfdDknOJFlPsj6bzZY8nCTpNxYu8CQ3A09X1fkX26+qzlbVWlWtraysLHo4SdIWy6zAbwTeleRR4MvA25N8cZBUkqQdLVzgVfWxqjpRVavArcC3quq2wZJpkqYw7dsc5uiQo81U+iRvA/5qp9sInUovSXu33VT6pe8DB6iqbwPfHuK5JEm74zsxJakpC1ySmrLAJakpC1ySmrLAJakpC1ySmrLAJakpC1ySmrLAJakpC1ySmrLAJakpC1ySmrLAJampQT6NUPvnMEzaNoc5DkqODlPptU8Oy6Rtc5jjIOSY/FR6SdJ4LHBJasoCl6SmLHBJasoCb+SwTNo2hzkOQo42U+l3y6n0krR3202ldwUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLU1MJT6ZNcC/wdcA1QwNmqunuoYFMzhQnXU8kxhQzmMEeHHFOeSv8s8JGqOgXcAHwgyalhYk3LFCZcTyXHFDKYwxwdckx6Kn1VXaqqB+aPfwFcAI4PFUyS9OIGuQaeZBW4Hrj/Ct87k2Q9yfpsNhvicJIkBijwJK8EvgJ8qKp+vvX7VXW2qtaqam1lZWXZw0mS5pYq8CQvY6O876mq+4aJJEnajYULPEmAzwIXqupTw0WanilMuJ5KjilkMIc5OuSY9FT6JH8I/DvwEPDcfPPHq+pr2/2MU+klae+2m0q/8H3gVfUfQJZKJUlamO/ElKSmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmFv40wv0yhcnS5pheBnOYo0OOKU+lf8lNYbK0OaaXwRzm6JBj0lPpJUnjssAlqSkLXJKassAlqalJF/gUJkubY3oZzGGODjkmPZV+EU6ll6S9224q/aRX4JKk7VngktSUBS5JTVngktSUBS5JTe3rXShJZsBjC/74MeDHA8bpzvPxPM/FZp6PzQ7C+fjdqlrZunFfC3wZSdavdBvNYeX5eJ7nYjPPx2YH+Xx4CUWSmrLAJampTgV+duwAE+P5eJ7nYjPPx2YH9ny0uQYuSdqs0wpcknQZC1ySmmpR4EnemeT7SX6Q5M6x84wlybVJ/jXJw0m+m+SOsTNNQZIjSR5M8k9jZxlbklcnuTfJ95JcSPIHY2caS5IPz39P/jvJl5L89tiZhjb5Ak9yBPgb4I+AU8B7kpwaN9VongU+UlWngBuADxzic3G5O4ALY4eYiLuBr1fVG4E3cUjPS5LjwAeBtaq6DjgC3DpuquFNvsCB3wd+UFWPVNUzwJeBW0bONIqqulRVD8wf/4KNX87j46YaV5ITwE3AubGzjC3Jq4C3Ap8FqKpnquqno4Ya11HgFUmOAlcBPxo5z+A6FPhx4InLvr7IIS8tgCSrwPXA/SNHGdungY8Cz42cYwpOAjPg8/NLSueSXD12qDFU1ZPAJ4HHgUvAz6rqG+OmGl6HAtcWSV4JfAX4UFX9fOw8Y0lyM/B0VZ0fO8tEHAXeDHymqq4HfgkcyteMkryGjb/UTwKvA65Octu4qYbXocCfBK697OsT822HUpKXsVHe91TVfWPnGdmNwLuSPMrGpbW3J/niuJFGdRG4WFW/+avsXjYK/TB6B/DDqppV1a+B+4C3jJxpcB0K/L+A1yc5meTlbLwQ8dWRM40iSdi4vnmhqj41dp6xVdXHqupEVa2y8f/iW1V14FZZu1VVTwFPJHnDfNNp4OERI43pceCGJFfNf29OcwBf0D06doCdVNWzSf4C+Bc2Xkn+XFV9d+RYY7kR+FPgoSTfmW/7eFV9bbxImpi/BO6ZL3YeAW4fOc8oqur+JPcCD7Bx99aDHMC31PtWeklqqsMlFEnSFVjgktSUBS5JTVngktSUBS5JTVngktSUBS5JTf0/DIZ7SVVTrxQAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import torchvision\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "train = torchvision.datasets.MNIST(root='../datasets/', train=True, download=True)\n", + "x = train.data.numpy()\n", + "x = np.expand_dims(x, 1)\n", + "t = train.targets.numpy() + 1\n", + "\n", + "# print(x.shape, t.shape)\n", + "\n", + "# test = torchvision.datasets.MNIST(root='../datasets/', train=False, download=True)\n", + "# x = test.data.numpy()\n", + "# x = np.expand_dims(x, 1)\n", + "# t = test.targets.numpy() + 1\n", + "\n", + "e, t = datasets.increase_censoring(np.ones(t.shape),t,.5)\n", + "plt.scatter(train.targets.numpy(),t)\n", + "\n", + "print(x.shape, t.shape, e.shape)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(9105, 44) (9105,) (9105,)\n" + ] + } + ], + "source": [ + "x, t, e = datasets.load_dataset('SUPPORT')\n", + "print(x.shape, t.shape, e.shape)\n", + "x = np.random.random((9105,1,100,100))\n", + "\n", + "times = np.quantile(t[e==1], [0.25, 0.5, 0.75]).tolist()\n", + "\n", + "cv_folds = 5\n", + "folds = list(range(cv_folds))*10000\n", + "folds = np.array(folds[:len(x)])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "On Fold: 0\n", + "(7284, 1, 100, 100)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/zfsauton2/home/chufang/anaconda3/lib/python3.8/site-packages/torch/nn/modules/container.py:550: UserWarning: Setting attributes on ParameterDict is not supported.\n", + " warnings.warn(\"Setting attributes on ParameterDict is not supported.\")\n", + " 0%| | 0/10000 [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m \u001b[0mcis_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconcordance_index_ipcw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0met_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0met_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_risk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mcis\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcis_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/sksurv/metrics.py\u001b[0m in \u001b[0;36mconcordance_index_ipcw\u001b[0;34m(survival_train, survival_test, estimate, tau, tied_tol)\u001b[0m\n\u001b[1;32m 300\u001b[0m \u001b[0msurvival_test\u001b[0m \u001b[0;34m=\u001b[0m 
+   "source": [
+    "cis = []\n",
+    "for fold in range(cv_folds):\n",
+    "    \n",
+    "    print (\"On Fold:\", fold)\n",
+    "    \n",
+    "    x_train, t_train, e_train = x[folds!=fold], t[folds!=fold], e[folds!=fold]\n",
+    "    x_test, t_test, e_test = x[folds==fold], t[folds==fold], e[folds==fold]\n",
+    "    \n",
+    "    et_train = np.array([(e_train[i], t_train[i]) for i in range(len(e_train))],\n",
+    "                        dtype=[('e', bool), ('t', int)])\n",
+    "    et_test = np.array([(e_test[i], t_test[i]) for i in range(len(e_test))],\n",
+    "                       dtype=[('e', bool), ('t', int)])\n",
+    "    \n",
+    "    model = CoxPHSurvivalAnalysis(alpha=1e-3)\n",
+    "    model.fit(x_test, et_test)\n",
+    "\n",
+    "    out_risk = model.predict_survival_function(x_test)\n",
+    "    \n",
+    "    cis_ = []\n",
+    "    for i in range(len(times)):\n",
+    "        cis_.append(concordance_index_ipcw(et_train, et_test, out_risk, times[i])[0])\n",
+    "    cis.append(cis_)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 54,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "3"
+      ]
+     },
+     "execution_count": 54,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "time = 6\n",
+    "int(np.where(out_risk[0].x == time)[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 50,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([   3,    4,    5, ..., 2027, 2028, 2029])  <full dump of the unique event times elided>"
+      ]
+     },
+     "execution_count": 50,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "out_risk[0].x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = CoxPHSurvivalAnalysis(alpha=1e-3)\n",
+    "model.fit(x_test, et_test)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 95,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "array([0.74335312, 0.7045087 , 0.68096073])"
+      ]
+     },
+     "execution_count": 95,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "np.mean(cis,axis=0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "out_risk = model.predict_risk(x, times)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "DeepSurvivalMachinesTorch(\n",
+       "  (act): SELU()\n",
+       "  (embedding): Sequential(\n",
+       "    (0): Linear(in_features=44, out_features=100, bias=False)\n",
+       "    (1): ReLU6()\n",
+       "    (2): Linear(in_features=100, out_features=100, bias=False)\n",
+       "    (3): ReLU6()\n",
+       "  )\n",
+       "  (gate): Sequential(\n",
+       "    (0): Linear(in_features=100, out_features=3, bias=False)\n",
+       "  )\n",
+       "  (scaleg): Sequential(\n",
+       "    (0): Linear(in_features=100, out_features=3, bias=True)\n",
+       "  )\n",
+       "  (shapeg): Sequential(\n",
+       "    (0): Linear(in_features=100, out_features=3, bias=True)\n",
+       "  )\n",
+       ")"
+      ]
+     },
+     "execution_count": 32,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "model.torch_model.eval()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "out_survival = model.predict_survival(x, times)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "metadata": {},
+   "outputs": [],
"source": [ + "from matplotlib import pyplot as plt\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "from sksurv.metrics import brier_score, concordance_index_ipcw" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "et = np.array([(e[i], t[i]) for i in range(len(e))],\n", + " dtype=[('e', bool), ('t', int)])\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.13039755, 0.20234974, 0.21643684])" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "brier_score(et, et, out_survival, times )" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.7519513749695589\n", + "0.7074775823879251\n", + "0.678728630898966\n" + ] + } + ], + "source": [ + "for i in range(len(times)):\n", + " print(concordance_index_ipcw(et, et, out_risk[:,i], times[i])[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "from sksurv.linear_model import CoxPHSurvivalAnalysis\n", + "\n", + "estimator = CoxPHSurvivalAnalysis(alpha=1e-3).fit(x, et,)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "surv_funcs = estimator.predict(x)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 0.86249313, 0.16849345, -0.45380257, ..., -0.14997697,\n", + " 0.35619347, -0.12209867])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "surv_funcs" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.6924659134706312\n", + "0.6741630293711603\n", + "0.6724802772351569\n" + ] + } + ], + "source": [ + "for i in range(len(times)):\n", + " print(concordance_index_ipcw(et, et, surv_funcs, times[i])[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/conv_example.py b/examples/conv_example.py new file mode 100644 index 0000000..7cac828 --- /dev/null +++ b/examples/conv_example.py @@ -0,0 +1,51 @@ +import importlib +import os, sys +sys.path.insert(1, os.path.join(sys.path[0], '..')) + +from dsm import datasets, DeepSurvivalMachines, DeepConvolutionalSurvivalMachines +import numpy as np +from sksurv.metrics import concordance_index_ipcw, brier_score + +x, t, e = datasets.load_dataset('MNIST') +print(x.shape, t.shape, e.shape) +# x = np.random.random((9105,1,100,100)) + +times = np.quantile(t[e==1], [0.25, 0.5, 
0.75]).tolist() + +cv_folds = 6 +folds = list(range(cv_folds))*10000 +folds = np.array(folds[:len(x)]) + +cis = [] +brs = [] +for fold in range(cv_folds): + + print ("On Fold:", fold) + + x_train, t_train, e_train = x[folds!=fold], t[folds!=fold], e[folds!=fold] + x_test, t_test, e_test = x[folds==fold], t[folds==fold], e[folds==fold] + + print (x_train.shape) + +# model = DeepSurvivalMachines(distribution='Weibull', layers=[100]) + model = DeepConvolutionalSurvivalMachines(distribution='Weibull', hidden=64) + model.fit(x_train, t_train, e_train, iters=2, learning_rate=1e-3, batch_size=101) + + et_train = np.array([(e_train[i], t_train[i]) for i in range(len(e_train))], + dtype=[('e', bool), ('t', int)]) + + et_test = np.array([(e_test[i], t_test[i]) for i in range(len(e_test))], + dtype=[('e', bool), ('t', int)]) + + out_risk = model.predict_risk(x_test, times) + out_survival = model.predict_survival(x_test, times) + + cis_ = [] + for i in range(len(times)): + cis_.append(concordance_index_ipcw(et_train, et_test, out_risk[:,i], times[i])[0]) + cis.append(cis_) + + brs.append(brier_score(et_train, et_test, out_survival, times )[1]) + +print ("Concordance Index:", np.mean(cis,axis=0)) +print ("Brier Score:", np.mean(brs,axis=0)) diff --git a/examples/datasets/MNIST/processed/test.pt b/examples/datasets/MNIST/processed/test.pt new file mode 100644 index 0000000..4fb0144 Binary files /dev/null and b/examples/datasets/MNIST/processed/test.pt differ diff --git a/examples/datasets/MNIST/processed/training.pt b/examples/datasets/MNIST/processed/training.pt new file mode 100644 index 0000000..aa5a7a2 Binary files /dev/null and b/examples/datasets/MNIST/processed/training.pt differ diff --git a/examples/datasets/MNIST/raw/t10k-images-idx3-ubyte b/examples/datasets/MNIST/raw/t10k-images-idx3-ubyte new file mode 100644 index 0000000..1170b2c Binary files /dev/null and b/examples/datasets/MNIST/raw/t10k-images-idx3-ubyte differ diff --git a/examples/datasets/MNIST/raw/t10k-images-idx3-ubyte.gz b/examples/datasets/MNIST/raw/t10k-images-idx3-ubyte.gz new file mode 100644 index 0000000..5ace8ea Binary files /dev/null and b/examples/datasets/MNIST/raw/t10k-images-idx3-ubyte.gz differ diff --git a/examples/datasets/MNIST/raw/t10k-labels-idx1-ubyte b/examples/datasets/MNIST/raw/t10k-labels-idx1-ubyte new file mode 100644 index 0000000..d1c3a97 Binary files /dev/null and b/examples/datasets/MNIST/raw/t10k-labels-idx1-ubyte differ diff --git a/examples/datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz b/examples/datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz new file mode 100644 index 0000000..a7e1415 Binary files /dev/null and b/examples/datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz differ diff --git a/examples/datasets/MNIST/raw/train-images-idx3-ubyte b/examples/datasets/MNIST/raw/train-images-idx3-ubyte new file mode 100644 index 0000000..bbce276 Binary files /dev/null and b/examples/datasets/MNIST/raw/train-images-idx3-ubyte differ diff --git a/examples/datasets/MNIST/raw/train-images-idx3-ubyte.gz b/examples/datasets/MNIST/raw/train-images-idx3-ubyte.gz new file mode 100644 index 0000000..b50e4b6 Binary files /dev/null and b/examples/datasets/MNIST/raw/train-images-idx3-ubyte.gz differ diff --git a/examples/datasets/MNIST/raw/train-labels-idx1-ubyte b/examples/datasets/MNIST/raw/train-labels-idx1-ubyte new file mode 100644 index 0000000..d6b4c5d Binary files /dev/null and b/examples/datasets/MNIST/raw/train-labels-idx1-ubyte differ diff --git a/examples/datasets/MNIST/raw/train-labels-idx1-ubyte.gz 
b/examples/datasets/MNIST/raw/train-labels-idx1-ubyte.gz new file mode 100644 index 0000000..707a576 Binary files /dev/null and b/examples/datasets/MNIST/raw/train-labels-idx1-ubyte.gz differ