diff --git a/examples/SUPPORT_example.ipynb b/examples/SUPPORT_example.ipynb new file mode 100644 index 0000000..bbb74b0 --- /dev/null +++ b/examples/SUPPORT_example.ipynb @@ -0,0 +1,734 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from dsm import datasets, DeepSurvivalMachines\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "x, t, e = datasets.load_dataset('SUPPORT')" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "times = np.quantile(t[e==1], [0.25, 0.5, 0.75]).tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "cv_folds = 5\n", + "folds = list(range(cv_folds))*10000" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "folds = np.array(folds[:len(x)])" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "from sksurv.metrics import concordance_index_ipcw, brier_score" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 1%| | 54/10000 [00:00<00:18, 536.02it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "On Fold: 0\n", + "(7284, 44)\n", + "Pretraining the Underlying Distributions...\n", + "torch.Size([6192]) torch.Size([6192])\n", + "torch.Size([6192]) torch.Size([6192])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 13%|█▎ | 1312/10000 [00:02<00:13, 628.77it/s]\n", + " 0%| | 0/100 [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m \u001b[0mcis_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconcordance_index_ipcw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0met_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0met_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_risk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mcis\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcis_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/sksurv/metrics.py\u001b[0m in \u001b[0;36mconcordance_index_ipcw\u001b[0;34m(survival_train, survival_test, estimate, tau, tied_tol)\u001b[0m\n\u001b[1;32m 300\u001b[0m \u001b[0msurvival_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msurvival_test\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 301\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 302\u001b[0;31m \u001b[0mestimate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_check_estimate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_time\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 303\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 304\u001b[0m \u001b[0mcens\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCensoringDistributionEstimator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/sksurv/metrics.py\u001b[0m in \u001b[0;36m_check_estimate\u001b[0;34m(estimate, test_time)\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_check_estimate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_time\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m \u001b[0mestimate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcheck_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mensure_2d\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 32\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mestimate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m raise ValueError(\n", + "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 71\u001b[0m FutureWarning)\n\u001b[1;32m 72\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0marg\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 73\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 74\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minner_f\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36mcheck_array\u001b[0;34m(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator)\u001b[0m\n\u001b[1;32m 597\u001b[0m \u001b[0marray\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0marray\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcasting\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"unsafe\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 598\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 599\u001b[0;31m \u001b[0marray\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0morder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 600\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mComplexWarning\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 601\u001b[0m raise ValueError(\"Complex data not supported\\n\"\n", + "\u001b[0;32m~/anaconda3/lib/python3.8/site-packages/numpy/core/_asarray.py\u001b[0m in \u001b[0;36masarray\u001b[0;34m(a, dtype, order)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m \"\"\"\n\u001b[0;32m---> 85\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0morder\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 86\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: float() argument must be a string or a number, not 'StepFunction'" + ] + } + ], + "source": [ + "cis = []\n", + "for fold in range(cv_folds):\n", + " \n", + " print (\"On Fold:\", fold)\n", + " \n", + " x_train, t_train, e_train = x[folds!=fold], t[folds!=fold], e[folds!=fold]\n", + " x_test, t_test, e_test = x[folds==fold], t[folds==fold], e[folds==fold]\n", + " \n", + " et_train = np.array([(e_train[i], t_train[i]) for i in range(len(e_train))],\n", + " dtype=[('e', bool), ('t', int)])\n", + " et_test = np.array([(e_test[i], t_test[i]) for i in range(len(e_test))],\n", + " dtype=[('e', bool), ('t', int)])\n", + " \n", + " model = CoxPHSurvivalAnalysis(alpha=1e-3)\n", + " model.fit(x_test, et_test)\n", + "\n", + " out_risk = model.predict_survival_function(x_test)\n", + " \n", + " cis_ = []\n", + " for i in range(len(times)):\n", + " cis_.append(concordance_index_ipcw(et_train, et_test, out_risk, times[i])[0])\n", + " cis.append(cis_)" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "time = 6\n", + "int(np.where(out_risk[0].x == time)[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,\n", + " 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,\n", + " 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n", + " 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46,\n", + " 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,\n", + " 58, 59, 60, 62, 63, 64, 65, 66, 67, 68, 69,\n", + " 70, 71, 72, 74, 75, 77, 78, 79, 80, 81, 82,\n", + " 83, 84, 85, 86, 88, 90, 91, 92, 93, 94, 95,\n", + " 96, 97, 98, 100, 101, 102, 103, 104, 105, 106, 107,\n", + " 108, 109, 110, 111, 112, 114, 116, 117, 118, 119, 120,\n", + " 121, 122, 124, 126, 127, 128, 129, 130, 132, 133, 134,\n", + " 136, 137, 139, 142, 143, 145, 146, 147, 148, 149, 151,\n", + " 152, 153, 156, 157, 160, 162, 163, 164, 165, 166, 167,\n", + " 168, 170, 171, 172, 173, 174, 176, 180, 181, 182, 183,\n", + " 185, 186, 187, 189, 191, 193, 194, 195, 197, 198, 199,\n", + " 200, 201, 202, 203, 204, 205, 207, 208, 212, 213, 214,\n", + " 215, 217, 218, 220, 223, 224, 225, 227, 229, 230, 231,\n", + " 233, 234, 235, 236, 237, 240, 242, 244, 247, 248, 251,\n", + " 252, 253, 254, 258, 259, 260, 263, 264, 265, 266, 268,\n", + " 269, 273, 274, 276, 277, 279, 281, 283, 287, 288, 290,\n", + " 291, 292, 294, 295, 297, 299, 300, 303, 309, 310, 311,\n", + " 312, 313, 314, 316, 318, 319, 320, 321, 322, 323, 324,\n", + " 326, 328, 330, 335, 338, 339, 340, 343, 344, 346, 347,\n", + " 348, 350, 351, 352, 353, 355, 356, 359, 360, 361, 363,\n", + " 365, 366, 368, 370, 372, 377, 379, 380, 381, 382, 384,\n", + " 385, 386, 387, 389, 392, 393, 394, 395, 396, 397, 399,\n", + " 400, 401, 403, 404, 405, 406, 407, 408, 409, 410, 411,\n", + " 413, 415, 417, 418, 420, 421, 422, 423, 425, 428, 430,\n", + " 432, 433, 434, 435, 436, 440, 442, 444, 446, 447, 448,\n", + " 449, 450, 451, 453, 455, 459, 460, 461, 463, 464, 465,\n", + " 467, 468, 469, 470, 472, 473, 474, 477, 479, 482, 484,\n", + " 485, 486, 487, 489, 491, 492, 493, 494, 496, 497, 499,\n", + " 500, 501, 503, 504, 507, 509, 511, 513, 515, 517, 518,\n", + " 521, 523, 524, 526, 527, 528, 529, 531, 533, 534, 536,\n", + " 541, 542, 546, 548, 551, 552, 553, 554, 555, 557, 558,\n", + " 560, 562, 563, 564, 566, 567, 573, 575, 576, 577, 578,\n", + " 582, 584, 585, 586, 587, 588, 589, 591, 595, 597, 599,\n", + " 603, 604, 605, 608, 609, 610, 613, 615, 616, 617, 618,\n", + " 619, 620, 621, 623, 624, 626, 627, 628, 629, 631, 632,\n", + " 633, 634, 636, 637, 641, 643, 644, 648, 649, 650, 652,\n", + " 653, 655, 656, 657, 658, 659, 661, 662, 664, 665, 666,\n", + " 667, 668, 669, 670, 671, 674, 675, 677, 679, 680, 682,\n", + " 685, 686, 690, 692, 695, 702, 703, 705, 706, 707, 708,\n", + " 709, 710, 712, 714, 716, 717, 719, 720, 721, 724, 726,\n", + " 727, 734, 738, 741, 744, 745, 746, 747, 751, 756, 757,\n", + " 760, 761, 763, 765, 766, 768, 770, 772, 773, 774, 776,\n", + " 777, 779, 781, 783, 784, 786, 789, 790, 795, 797, 798,\n", + " 799, 800, 803, 804, 807, 808, 809, 811, 812, 814, 815,\n", + " 816, 817, 818, 819, 820, 821, 823, 824, 825, 827, 829,\n", + " 830, 831, 833, 835, 839, 842, 844, 845, 847, 849, 851,\n", + " 852, 853, 855, 857, 858, 861, 864, 867, 868, 869, 872,\n", + " 873, 875, 877, 878, 879, 883, 885, 887, 889, 890, 891,\n", + " 892, 897, 898, 904, 910, 914, 917, 918, 919, 923, 926,\n", + " 928, 929, 930, 934, 936, 937, 940, 941, 944, 946, 950,\n", + " 951, 954, 958, 965, 969, 970, 971, 972, 973, 977, 978,\n", + " 982, 984, 985, 986, 987, 988, 989, 992, 996, 998, 999,\n", + " 1000, 1006, 1009, 1011, 1012, 1017, 1018, 1021, 1022, 1023, 1029,\n", + " 1034, 1036, 1037, 1043, 1045, 1046, 1047, 1049, 1050, 1051, 1055,\n", + " 1059, 1060, 1063, 1064, 1068, 1070, 1072, 1073, 1074, 1075, 1078,\n", + " 1079, 1082, 1087, 1088, 1099, 1109, 1116, 1126, 1134, 1138, 1142,\n", + " 1162, 1164, 1172, 1174, 1177, 1185, 1201, 1212, 1213, 1224, 1227,\n", + " 1232, 1238, 1250, 1253, 1265, 1269, 1289, 1301, 1304, 1307, 1310,\n", + " 1312, 1320, 1321, 1326, 1327, 1328, 1342, 1344, 1345, 1346, 1347,\n", + " 1349, 1352, 1355, 1356, 1360, 1363, 1365, 1369, 1371, 1373, 1377,\n", + " 1379, 1380, 1382, 1384, 1385, 1388, 1391, 1392, 1396, 1398, 1401,\n", + " 1406, 1409, 1410, 1411, 1416, 1418, 1421, 1422, 1427, 1439, 1441,\n", + " 1442, 1444, 1449, 1452, 1455, 1458, 1466, 1467, 1474, 1475, 1484,\n", + " 1485, 1486, 1487, 1489, 1492, 1495, 1497, 1503, 1510, 1512, 1514,\n", + " 1517, 1518, 1519, 1521, 1530, 1531, 1534, 1539, 1542, 1543, 1547,\n", + " 1551, 1552, 1558, 1560, 1563, 1566, 1567, 1568, 1572, 1573, 1578,\n", + " 1579, 1593, 1596, 1599, 1600, 1605, 1610, 1613, 1614, 1618, 1622,\n", + " 1623, 1629, 1636, 1642, 1647, 1648, 1654, 1655, 1657, 1659, 1665,\n", + " 1670, 1671, 1676, 1677, 1681, 1683, 1686, 1688, 1689, 1691, 1697,\n", + " 1699, 1701, 1705, 1712, 1717, 1718, 1719, 1722, 1723, 1728, 1732,\n", + " 1733, 1734, 1739, 1740, 1742, 1745, 1747, 1748, 1761, 1763, 1767,\n", + " 1769, 1772, 1778, 1782, 1783, 1785, 1788, 1790, 1792, 1795, 1798,\n", + " 1801, 1807, 1812, 1814, 1819, 1820, 1823, 1825, 1826, 1830, 1845,\n", + " 1853, 1857, 1863, 1866, 1867, 1882, 1885, 1886, 1887, 1892, 1910,\n", + " 1911, 1915, 1916, 1918, 1921, 1928, 1938, 1940, 1944, 1945, 1948,\n", + " 1949, 1951, 1952, 1963, 1971, 1976, 1978, 1979, 1980, 1984, 1990,\n", + " 1992, 1995, 1998, 1999, 2001, 2007, 2009, 2010, 2012, 2014, 2016,\n", + " 2019, 2022, 2024, 2026, 2027, 2028, 2029])" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out_risk[0].x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = CoxPHSurvivalAnalysis(alpha=1e-3)\n", + "model.fit(x_test, et_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.74335312, 0.7045087 , 0.68096073])" + ] + }, + "execution_count": 95, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.mean(cis,axis=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "out_risk = model.predict_risk(x, times)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DeepSurvivalMachinesTorch(\n", + " (act): SELU()\n", + " (embedding): Sequential(\n", + " (0): Linear(in_features=44, out_features=100, bias=False)\n", + " (1): ReLU6()\n", + " (2): Linear(in_features=100, out_features=100, bias=False)\n", + " (3): ReLU6()\n", + " )\n", + " (gate): Sequential(\n", + " (0): Linear(in_features=100, out_features=3, bias=False)\n", + " )\n", + " (scaleg): Sequential(\n", + " (0): Linear(in_features=100, out_features=3, bias=True)\n", + " )\n", + " (shapeg): Sequential(\n", + " (0): Linear(in_features=100, out_features=3, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.torch_model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "out_survival = model.predict_survival(x, times)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "from sksurv.metrics import brier_score, concordance_index_ipcw" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "et = np.array([(e[i], t[i]) for i in range(len(e))],\n", + " dtype=[('e', bool), ('t', int)])\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.13039755, 0.20234974, 0.21643684])" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "brier_score(et, et, out_survival, times )" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.7519513749695589\n", + "0.7074775823879251\n", + "0.678728630898966\n" + ] + } + ], + "source": [ + "for i in range(len(times)):\n", + " print(concordance_index_ipcw(et, et, out_risk[:,i], times[i])[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "from sksurv.linear_model import CoxPHSurvivalAnalysis\n", + "\n", + "estimator = CoxPHSurvivalAnalysis(alpha=1e-3).fit(x, et,)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "surv_funcs = estimator.predict(x)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 0.86249313, 0.16849345, -0.45380257, ..., -0.14997697,\n", + " 0.35619347, -0.12209867])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "surv_funcs" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.6924659134706312\n", + "0.6741630293711603\n", + "0.6724802772351569\n" + ] + } + ], + "source": [ + "for i in range(len(times)):\n", + " print(concordance_index_ipcw(et, et, surv_funcs, times[i])[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}