From a33b74376749f76913031313e6f5ac7e320a7fba Mon Sep 17 00:00:00 2001 From: Chirag Nagpal Date: Sat, 2 Jan 2021 14:23:06 +0530 Subject: [PATCH] modified: dsm/dsm_api.py modified: dsm/dsm_torch.py modified: examples/conv_example.ipynb modified: dsm/utilities.py modified: examples/conv_example.ipynb modified: examples/pbc_final_experiment.ipynb --- dsm/dsm_api.py | 2 +- dsm/dsm_torch.py | 35 +-- examples/conv_example.ipynb | 411 +++++++++++++++++++++++++++++++++--- 3 files changed, 407 insertions(+), 41 deletions(-) diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py index 8298658..88ea114 100644 --- a/dsm/dsm_api.py +++ b/dsm/dsm_api.py @@ -338,7 +338,7 @@ def __init__(self, k=3, layers=None, hidden=None, distribution='Weibull', temp=1000., discount=1.0, typ='ConvNet'): super(DeepConvolutionalSurvivalMachines, self).__init__(k=k, distribution=distribution, - temp=temp, + temp=temp, discount=discount) self.hidden = hidden self.typ = typ diff --git a/dsm/dsm_torch.py b/dsm/dsm_torch.py index 7f42657..21296c4 100644 --- a/dsm/dsm_torch.py +++ b/dsm/dsm_torch.py @@ -40,7 +40,8 @@ __pdoc__ = {} for clsn in ['DeepSurvivalMachinesTorch', - 'DeepRecurrentSurvivalMachinesTorch']: + 'DeepRecurrentSurvivalMachinesTorch', + 'DeepConvolutionalSurvivalMachines']: for membr in ['training', 'dump_patches']: __pdoc__[clsn+'.'+membr] = False @@ -370,19 +371,27 @@ def create_conv_representation(inputdim, hidden, typ='ConvNet'): linear_dim = ((((inputdim-2) // 2) - 2) // 2) ** 2 linear_dim *= 16 embedding = nn.Sequential( - nn.Conv2d(1, 6, 3), - nn.ReLU(), - nn.MaxPool2d(2, 2), - nn.Conv2d(6, 16, 3), - nn.ReLU(), - nn.MaxPool2d(2, 2), - nn.Flatten(), - nn.Linear(linear_dim, 120), - nn.ReLU(), - nn.Linear(120, 84), - nn.ReLU(), - nn.Linear(84, hidden) + nn.Conv2d(1, 6, 3), + nn.ReLU6(), + nn.MaxPool2d(2, 2), + nn.Conv2d(6, 16, 3), + nn.ReLU6(), + nn.MaxPool2d(2, 2), + nn.Flatten(), + nn.Linear(linear_dim, hidden), + nn.ReLU6() ) + +# if typ == 'SimpleConvNet': +# inputdim = np.squeeze(inputdim) + +# layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), +# layers.MaxPooling2D(pool_size=(2, 2)), +# layers.Conv2D(64, kernel_size=(3, 3), activation="relu"), +# layers.MaxPooling2D(pool_size=(2, 2)), +# layers.Flatten(), + + return embedding class DeepConvolutionalSurvivalMachinesTorch(nn.Module): diff --git a/examples/conv_example.ipynb b/examples/conv_example.ipynb index febe8cb..1f9d93a 100644 --- a/examples/conv_example.ipynb +++ b/examples/conv_example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -17,7 +17,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -28,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 3, "metadata": { "scrolled": true }, @@ -42,7 +42,7 @@ }, { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAL6klEQVR4nO3dX4ild33H8fenu0pNBP+wQ6i7obMXoiwBGxlKaoqI64UlwfTKRkixqbI3tkaxSPQm9M4LEXNRhGXVWgzaEgOVVqyglbZQQmcTITWrIDF/Nm6aI/Ufgo0h317MkexMdjIz5zyZ5/nOvF83e+aZZ87z4WHnw2+e85zzTVUhSernt8YOIElajAUuSU1Z4JLUlAUuSU1Z4JLU1NH9PNixY8dqdXV1Pw8pSe2dP3/+x1W1snX7vhb46uoq6+vr+3lISWovyWNX2u4lFElqygKXpKYscElqygKXpKYscElqase7UJJ8DrgZeLqqrptvey3w98Aq8Cjw7qr6yUsXU5L6Wb3zn1+w7dFP3DTY8+9mBf63wDu3bLsT+GZVvR745vxrSdLclcr7xbYvYscCr6p/A/53y+ZbgC/MH38B+OPBEkmSdmXRa+DXVNWl+eOngGu22zHJmSTrSdZns9mCh5MkbbX0i5i1MRFi26kQVXW2qtaqam1l5QXvBJUkLWjRAv+fJL8DMP/36eEiSZJ2Y9EC/yrw3vnj9wL/OEwcSToYtrvbZMi7UHZzG+GXgLcBx5JcBO4CPgH8Q5L3AY8B7x4skSQdEEOW9ZXsWOBV9Z5tvnV64CySpD3wnZiS1JQFLklNWeCS1JQFLklNWeCS1JQFLklNWeCS1JQFLklNWeCS1JQFLklNWeCS1JQFLklNWeCS1NSOn0YoSR291BPhp5DBFbikA2c/JsJPIYMFLklNWeCS1JQFLklNWeCS1JQFLunA2Y+J8FPIkKoa7Ml2sra2Vuvr6/t2PEk6CJKcr6q1rdtdgUtSUxa4JDVlgUtSUxa4JDVlgUtSUxa4JDVlgUtSUxa4JDVlgUtSUxa4JDVlgUtSUxa4JDVlgUtSUxa4JDW11FT6JB8G3g8U8BBwe1X9aohgknqawjT4qeSY7FT6JMeBDwJrVXUdcAS4dahgkvqZwjT4qeToMJX+KPCKJEeBq4AfLR9JkrQbCxd4VT0JfBJ4HLgE/KyqvrF1vyRnkqwnWZ/NZosnlSRtsswllNcAtwAngdcBVye5bet+VXW2qtaqam1lZWXxpJKkTZa5hPIO4IdVNauqXwP3AW8ZJpYkaSfLFPjjwA1JrkoS4DRwYZhYkjqawjT4qeSY/FT6JH8N/AnwLPAg8P6q+r/t9ncqvSTt3XZT6Ze6D7yq7gLuWuY5JEmL8Z2YktSUBS5JTVngktSUBS5JTVngktSUBS5JTVngktSUBS5JTVngktSUBS5JTVngktSUBS5JTVngktTUUp9GKGk6pjCF3Rz7m8EVuHQATGEKuzn2P4MFLklNWeCS1JQFLklNWeCS1JQFLh0AU5jCbo79z7DUVPq9ciq9JO3ddlPpXYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1tdRU+iSvBs4B1wEF/HlV/ecAuaQ2pjD93BzTzDH1qfR3A1+vqjcCbwIuLB9J6mMK08/NMc0c+5Fh4RV4klcBbwX+DKCqngGeGSaWJGkny6zATwIz4PNJHkxyLsnVW3dKcibJepL12Wy2xOEkSZdbpsCPAm8GPlNV1wO/BO7culNVna2qtapaW1lZWeJwkqTLLVPgF4GLVXX//Ot72Sh0SdI+WLjAq+op4Ikkb5hvOg08PEgqqYkpTD83xzRzTH4qfZLfY+M2wpcDjwC3V9VPttvfqfSStHfbTaVf6j7wqvoO8IInlSS99HwnpiQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1tdSnEUpjmsLUcXOYY8wMrsDV0hSmjpvDHGNnsMAlqSkLXJKassAlqSkLXJKassDV0hSmjpvDHGNnWGoq/V45lV6S9m67qfSuwCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckppaeip9kiPAOvBkVd28fCRN3RSmfZvDHB1ydJhKfwdwYYDnUQNTmPZtDnN0yDH5qfRJTgA3AeeGiSNJ2q1lV+CfBj4KPLfdDknOJFlPsj6bzZY8nCTpNxYu8CQ3A09X1fkX26+qzlbVWlWtraysLHo4SdIWy6zAbwTeleRR4MvA25N8cZBUkqQdLVzgVfWxqjpRVavArcC3quq2wZJpkqYw7dsc5uiQo81U+iRvA/5qp9sInUovSXu33VT6pe8DB6iqbwPfHuK5JEm74zsxJakpC1ySmrLAJakpC1ySmrLAJakpC1ySmrLAJakpC1ySmrLAJakpC1ySmrLAJakpC1ySmrLAJampQT6NUPvnMEzaNoc5DkqODlPptU8Oy6Rtc5jjIOSY/FR6SdJ4LHBJasoCl6SmLHBJasoCb+SwTNo2hzkOQo42U+l3y6n0krR3202ldwUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLU1MJT6ZNcC/wdcA1QwNmqunuoYFMzhQnXU8kxhQzmMEeHHFOeSv8s8JGqOgXcAHwgyalhYk3LFCZcTyXHFDKYwxwdckx6Kn1VXaqqB+aPfwFcAI4PFUyS9OIGuQaeZBW4Hrj/Ct87k2Q9yfpsNhvicJIkBijwJK8EvgJ8qKp+vvX7VXW2qtaqam1lZWXZw0mS5pYq8CQvY6O876mq+4aJJEnajYULPEmAzwIXqupTw0WanilMuJ5KjilkMIc5OuSY9FT6JH8I/DvwEPDcfPPHq+pr2/2MU+klae+2m0q/8H3gVfUfQJZKJUlamO/ElKSmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmFv40wv0yhcnS5pheBnOYo0OOKU+lf8lNYbK0OaaXwRzm6JBj0lPpJUnjssAlqSkLXJKassAlqalJF/gUJkubY3oZzGGODjkmPZV+EU6ll6S9224q/aRX4JKk7VngktSUBS5JTVngktSUBS5JTe3rXShJZsBjC/74MeDHA8bpzvPxPM/FZp6PzQ7C+fjdqlrZunFfC3wZSdavdBvNYeX5eJ7nYjPPx2YH+Xx4CUWSmrLAJampTgV+duwAE+P5eJ7nYjPPx2YH9ny0uQYuSdqs0wpcknQZC1ySmmpR4EnemeT7SX6Q5M6x84wlybVJ/jXJw0m+m+SOsTNNQZIjSR5M8k9jZxlbklcnuTfJ95JcSPIHY2caS5IPz39P/jvJl5L89tiZhjb5Ak9yBPgb4I+AU8B7kpwaN9VongU+UlWngBuADxzic3G5O4ALY4eYiLuBr1fVG4E3cUjPS5LjwAeBtaq6DjgC3DpuquFNvsCB3wd+UFWPVNUzwJeBW0bONIqqulRVD8wf/4KNX87j46YaV5ITwE3AubGzjC3Jq4C3Ap8FqKpnquqno4Ya11HgFUmOAlcBPxo5z+A6FPhx4InLvr7IIS8tgCSrwPXA/SNHGdungY8Cz42cYwpOAjPg8/NLSueSXD12qDFU1ZPAJ4HHgUvAz6rqG+OmGl6HAtcWSV4JfAX4UFX9fOw8Y0lyM/B0VZ0fO8tEHAXeDHymqq4HfgkcyteMkryGjb/UTwKvA65Octu4qYbXocCfBK697OsT822HUpKXsVHe91TVfWPnGdmNwLuSPMrGpbW3J/niuJFGdRG4WFW/+avsXjYK/TB6B/DDqppV1a+B+4C3jJxpcB0K/L+A1yc5meTlbLwQ8dWRM40iSdi4vnmhqj41dp6xVdXHqupEVa2y8f/iW1V14FZZu1VVTwFPJHnDfNNp4OERI43pceCGJFfNf29OcwBf0D06doCdVNWzSf4C+Bc2Xkn+XFV9d+RYY7kR+FPgoSTfmW/7eFV9bbxImpi/BO6ZL3YeAW4fOc8oqur+JPcCD7Bx99aDHMC31PtWeklqqsMlFEnSFVjgktSUBS5JTVngktSUBS5JTVngktSUBS5JTf0/DIZ7SVVTrxQAAAAASUVORK5CYII=\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAL6klEQVR4nO3dX4ild33H8fenu0pNRLTs9I+7oRNB1BCwkaGkpoi4XlgSjDetEVJsWtkbW6NYJHoTeueFiLkowrJqLQZtiYFKKypopRYk9GwipLotSIybjWtzpFZFaGPItxdzJDuTnZ2Zc56d53znvF83M/PMM+f58DDz4TfPec75pqqQJPXzK2MHkCTNxwKXpKYscElqygKXpKYscElq6uhBHuzYsWO1vr5+kIeUpPbOnj37o6pa2779QAt8fX2dyWRykIeUpPaSfP9y272EIklNWeCS1JQFLklNWeCS1JQFLklN7XoXSpJPArcBT1XVjbNtvwb8HbAOPA78UVX9+OrFlKSe1u/5p+dte/zDtw7y2HtZgf8N8JZt2+4BvlpVrwS+OvtaknSJy5X3lbbv164FXlX/Avz3ts23A5+eff5p4G2DpJEk7dm818B/o6ouAsw+/vpOOyY5lWSSZDKdTuc8nCRpu6v+JGZVna6qjaraWFt73itBJUlzmrfA/yvJbwHMPj41XCRJ0l7MW+BfAN45+/ydwD8ME0eSDo+d7jYZ6i6UvdxG+FngjcCxJBeAe4EPA3+f5M+A88AfDpJGkg6Zocr6cnYt8Kp6xw7fOjlwFknSPvhKTElqygKXpKYscElqygKXpKYscElqygKXpKYscElqygKXpKYscElqygKXpKYscElqygKXpKYscElqatd3I5Skjq7mNPhlyeEKXNKhc7WnwS9LDgtckpqywCWpKQtckpqywCWpKQtc0qFztafBL0uOVNUgD7QXGxsbNZlMDux4knQYJDlbVRvbt7sCl6SmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmLHBJamqhqfRJ3ge8CyjgUeCuqvrfIYJJ6mkVpsEvS465V+BJjgPvATaq6kbgCHDHIKkktbQq0+CXJceil1COAi9KchS4BvjB4pEkSXsxd4FX1ZPAR4DzwEXgJ1X1le37JTmVZJJkMp1O508qSdpikUsoLwNuB64HXg5cm+TO7ftV1emq2qiqjbW1tfmTSpK2WOQSypuB71XVtKp+ATwIvH6YWJKk3SxS4OeBm5NckyTASeDcMLEkdbQq0+CXJcdCU+mT/BXwduAZ4BHgXVX1fzvt71R6Sdq/nabSL3QfeFXdC9y7yGNIkubjKzElqSkLXJKassAlqSkLXJKassAlqSkLXJKassAlqSkLXJKassAlqSkLXJKassAlqSkLXJKassAlqamF3o1Q0vJYhSns5tjKFbh0CKzKFHZzbGWBS1JTFrgkNWWBS1JTFrgkNWWBS4fAqkxhN8dWC02l3y+n0kvS/u00ld4VuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMLTaVP8lLgDHAjUMCfVtU3hwgmdbEK08/NsZw5Fl2B3wd8qapeDbwWOLd4JKmPVZl+bo7lzDH3CjzJS4A3AH8CUFVPA08PkkqStKtFVuCvAKbAp5I8kuRMkmu375TkVJJJksl0Ol3gcJKkSy1S4EeB1wEfr6qbgJ8D92zfqapOV9VGVW2sra0tcDhJ0qUWKfALwIWqemj29QNsFrok6QDMXeBV9UPgiSSvmm06CXxnkFRSE6sy/dwcy5ljoan0SX6HzdsIXwg8BtxVVT/eaX+n0kvS/u00lX6h+8Cr6lvA8x5UknT1+UpMSWrKApekpixwSWrKApekpixwSWrKApekpixwSWrKApekpixwSWrKApekpixwSWrKApekpixwSWpqoXcjlMa0ClPHzWGOK3EFrpZWZeq4OcxxJRa4JDVlgUtSUxa4JDVlgUtSUxa4WlqVqePmMMeVLDSVfr+cSi9J+7fTVHpX4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU0tPJU+yRFgAjxZVbctHkkdLMPE72XIYA5zjJljiBX43cC5AR5HTSzDxO9lyGAOc4ydY6ECT3ICuBU4M0gaSdKeLboC/xjwAeDZnXZIcirJJMlkOp0ueDhJ0i/NXeBJbgOeqqqzV9qvqk5X1UZVbaytrc17OEnSNouswG8B3prkceBzwJuSfGaQVJKkXc1d4FX1wao6UVXrwB3A16rqzsGSaWktw8TvZchgDnOMnWOQqfRJ3gj85W63ETqVXpL2b6ep9AvfBw5QVV8Hvj7EY0mS9sZXYkpSUxa4JDVlgUtSUxa4JDVlgUtSUxa4JDVlgUtSUxa4JDVlgUtSUxa4JDVlgUtSUxa4JDVlgUtSU4O8G6EOzipM2u6UwRzmGDOHK/BGVmXSdpcM5jDH2DkscElqygKXpKYscElqygKXpKYs8EZWZdJ2lwzmMMfYOQaZSr9XTqWXpP3baSq9K3BJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6SmLHBJasoCl6Sm5p5Kn+Q64G+B3wSeBU5X1X1DBVs2qzDhuluOZchgDnOMmWORFfgzwPur6jXAzcC7k9wwSKolsyoTrjvlWIYM5jDH2DnmLvCqulhVD88+/xlwDjg+SCpJ0q4GuQaeZB24CXjoMt87lWSSZDKdToc4nCSJAQo8yYuBzwPvraqfbv9+VZ2uqo2q2lhbW1v0cJKkmYUKPMkL2Czv+6vqwWEiSZL2Yu4CTxLgE8C5qvrocJGWz6pMuO6UYxkymMMcY+eYeyp9kt8HvgE8yuZthAAfqqov7vQzTqWXpP3baSr93PeBV9W/AlkolSRpbr4SU5KassAlqSkLXJKassAlqSkLXJKassAlqSkLXJKassAlqSkLXJKassAlqSkLXJKassAlqSkLXJKamvvdCA/KKkyWNkffDOYwx5g5lnoFviqTpc3RM4M5zDF2jqUucEnSzixwSWrKApekpixwSWpqqQt8VSZLm6NnBnOYY+wcc0+ln4dT6SVp/3aaSr/UK3BJ0s4scElqygKXpKYscElqygKXpKYO9C6UJFPg+3P++DHgRwPG6c7z8RzPxVaej60Ow/n47apa277xQAt8EUkml7uNZlV5Pp7judjK87HVYT4fXkKRpKYscElqqlOBnx47wJLxfDzHc7GV52OrQ3s+2lwDlyRt1WkFLkm6hAUuSU21KPAkb0nyn0m+m+SesfOMJcl1Sf45ybkk305y99iZlkGSI0keSfKPY2cZW5KXJnkgyX/Mfk9+b+xMY0nyvtnfyb8n+WySXx0709CWvsCTHAH+GvgD4AbgHUluGDfVaJ4B3l9VrwFuBt69wufiUncD58YOsSTuA75UVa8GXsuKnpckx4H3ABtVdSNwBLhj3FTDW/oCB34X+G5VPVZVTwOfA24fOdMoqupiVT08+/xnbP5xHh831biSnABuBc6MnWVsSV4CvAH4BEBVPV1V/zNuqlEdBV6U5ChwDfCDkfMMrkOBHweeuOTrC6x4aQEkWQduAh4aN8noPgZ8AHh27CBL4BXAFPjU7JLSmSTXjh1qDFX1JPAR4DxwEfhJVX1l3FTD61Dgucy2lb73McmLgc8D762qn46dZyxJbgOeqqqzY2dZEkeB1wEfr6qbgJ8DK/mcUZKXsfmf+vXAy4Frk9w5bqrhdSjwC8B1l3x9gkP4r9BeJXkBm+V9f1U9OHaekd0CvDXJ42xeWntTks+MG2lUF4ALVfXL/8oeYLPQV9Gbge9V1bSqfgE8CLx+5EyD61Dg/wa8Msn1SV7I5hMRXxg50yiShM3rm+eq6qNj5xlbVX2wqk5U1Tqbvxdfq6pDt8raq6r6IfBEklfNNp0EvjNipDGdB25Ocs3s7+Ykh/AJ3aNjB9hNVT2T5M+BL7P5TPInq+rbI8cayy3AHwOPJvnWbNuHquqLI2bScvkL4P7ZYucx4K6R84yiqh5K8gDwMJt3bz3CIXxJvS+ll6SmOlxCkSRdhgUuSU1Z4JLUlAUuSU1Z4JLUlAUuSU1Z4JLU1P8D28d6SClli4MAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] @@ -79,29 +79,248 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([2925., 3351., 2944., 3025., 2957., 2707., 2959., 3044., 2901.,\n", + " 2975.]),\n", + " array([ 1. , 1.9, 2.8, 3.7, 4.6, 5.5, 6.4, 7.3, 8.2, 9.1, 10. ]),\n", + " )" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD7CAYAAACG50QgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAASVElEQVR4nO3df4xd5Z3f8fdnbUq8yaIFMSDHtmqaOu0apDVlZLlFqugSFW92VZOVkBypwaqQHCGnJVWkLuSfZP+wxEpJtqXaIDkbimmzICvJCiuF7RJvoigSwTuwbIxxLKxAYWIXz26axvQPNna+/eM+7l6Gy/ywZ+Y6PO+XdHXP/d7nOec5V57PPfOcc8apKiRJffilcQ9AkrRyDH1J6oihL0kdMfQlqSOGviR1xNCXpI7MG/pJ3pPkcJK/SnI0ye+1+meT/CjJ8+3x4aE+9yU5keR4ktuG6jclOdLeeyBJlme3JEmjZL7r9Fswv7eq3khyGfBd4B5gO/BGVX1uVvvNwKPAVuD9wDeBD1bVuSSHW9/vAU8AD1TVk0u8T5Kkd7B6vgY1+FZ4o728rD3m+qbYATxWVW8CLyc5AWxN8gpwRVU9DZDkEeB2YM7Qv/rqq2vjxo3zDVOSNOTZZ5/966qamF2fN/QBkqwCngX+IfCHVfVMkt8EPpHkTmAK+FRV/W9gHYMj+fOmW+1nbXl2fU4bN25kampqIcOUJDVJ/ueo+oJO5FbVuaraAqxncNR+A/Ag8AFgC3AK+Pz5bY1axRz1UYPdnWQqydTMzMxChihJWoBFXb1TVT8Bvg1sr6rX25fBz4EvMZjDh8ER/IahbuuBk62+fkR91Hb2VdVkVU1OTLzttxNJ0gVayNU7E0l+tS2vAT4E/CDJ2qFmHwFeaMsHgZ1JLk9yHbAJOFxVp4AzSba1k8N3Ao8v4b5IkuaxkDn9tcD+Nq//S8CBqvpGkv+aZAuDKZpXgI8DVNXRJAeAF4GzwJ6qOtfWdTfwMLCGwQlcr9yRpBU07yWb4zY5OVmeyJWkxUnybFVNzq57R64kdcTQl6SOGPqS1BFDX5I6sqA7crU4G+/972Pb9iv3/9bYti3p0ueRviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHVk3tBP8p4kh5P8VZKjSX6v1a9K8lSSl9rzlUN97ktyIsnxJLcN1W9KcqS990CSLM9uSZJGWch/l/gm8BtV9UaSy4DvJnkS+B3gUFXdn+Re4F7gd5NsBnYC1wPvB76Z5INVdQ54ENgNfA94AtgOPLnkeyVpWYzrvwL1vwFdOvOGflUV8EZ7eVl7FLADuKXV9wPfBn631R+rqjeBl5OcALYmeQW4oqqeBkjyCHA7hv67gv8vsPSLYUFz+klWJXkeOA08VVXPANdW1SmA9nxNa74OeG2o+3SrrWvLs+ujtrc7yVSSqZmZmcXsjyRpDgsK/ao6V1VbgPUMjtpvmKP5qHn6mqM+anv7qmqyqiYnJiYWMkRJ0gIsZE7//6uqnyT5NoO5+NeTrK2qU0nWMvgtAAZH8BuGuq0HTrb6+hF1SbpkvdvOY8wb+kkmgJ+1wF8DfAj4feAgsAu4vz0/3rocBP44yRcYnMjdBByuqnNJziTZBjwD3An856XeIfXn3fZDuRDjPIeiX2wLOdJfC+xPsorBdNCBqvpGkqeBA0nuAl4F7gCoqqNJDgAvAmeBPe3KHYC7gYeBNQxO4HoSd4kZBpLmspCrd74P3Dii/jfAre/QZy+wd0R9CpjrfMCSMgAl6a28I1eSOrKoE7mSNA7+1r50PNKXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I64t/ekS6Qfw9Gv4g80pekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1ZN7QT7IhybeSHEtyNMk9rf7ZJD9K8nx7fHioz31JTiQ5nuS2ofpNSY609x5IkuXZLUnSKAu5I/cs8Kmqei7JrwDPJnmqvfcHVfW54cZJNgM7geuB9wPfTPLBqjoHPAjsBr4HPAFsB55cml2RJM1n3iP9qjpVVc+15TPAMWDdHF12AI9V1ZtV9TJwAtiaZC1wRVU9XVUFPALcftF7IElasEXN6SfZCNwIPNNKn0jy/SQPJbmy1dYBrw11m261dW15dn3UdnYnmUoyNTMzs5ghSpLmsODQT/I+4GvAJ6vqpwymaj4AbAFOAZ8/33RE95qj/vZi1b6qmqyqyYmJiYUOUZI0jwWFfpLLGAT+V6rq6wBV9XpVnauqnwNfAra25tPAhqHu64GTrb5+RF2StEIWcvVOgC8Dx6rqC0P1tUPNPgK80JYPAjuTXJ7kOmATcLiqTgFnkmxr67wTeHyJ9kOStAALuXrnZuBjwJEkz7fap4GPJtnCYIrmFeDjAFV1NMkB4EUGV/7saVfuANwNPAysYXDVjlfuSNIKmjf0q+q7jJ6Pf2KOPnuBvSPqU8ANixmgJGnpeEeuJHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqyLyhn2RDkm8lOZbkaJJ7Wv2qJE8leak9XznU574kJ5IcT3LbUP2mJEfaew8kyfLsliRplIUc6Z8FPlVVvwZsA/Yk2QzcCxyqqk3Aofaa9t5O4HpgO/DFJKvauh4EdgOb2mP7Eu6LJGke84Z+VZ2qqufa8hngGLAO2AHsb832A7e35R3AY1X1ZlW9DJwAtiZZC1xRVU9XVQGPDPWRJK2ARc3pJ9kI3Ag8A1xbVadg8MUAXNOarQNeG+o23Wrr2vLs+qjt7E4ylWRqZmZmMUOUJM1hwaGf5H3A14BPVtVP52o6olZz1N9erNpXVZNVNTkxMbHQIUqS5rGg0E9yGYPA/0pVfb2VX29TNrTn060+DWwY6r4eONnq60fUJUkrZCFX7wT4MnCsqr4w9NZBYFdb3gU8PlTfmeTyJNcxOGF7uE0BnUmyra3zzqE+kqQVsHoBbW4GPgYcSfJ8q30auB84kOQu4FXgDoCqOprkAPAigyt/9lTVudbvbuBhYA3wZHtIklbIvKFfVd9l9Hw8wK3v0GcvsHdEfQq4YTEDlCQtHe/IlaSOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHZk39JM8lOR0kheGap9N8qMkz7fHh4feuy/JiSTHk9w2VL8pyZH23gNJsvS7I0may0KO9B8Gto+o/0FVbWmPJwCSbAZ2Ate3Pl9Msqq1fxDYDWxqj1HrlCQto3lDv6q+A/x4gevbATxWVW9W1cvACWBrkrXAFVX1dFUV8Ahw+4UOWpJ0YS5mTv8TSb7fpn+ubLV1wGtDbaZbbV1bnl0fKcnuJFNJpmZmZi5iiJKkYRca+g8CHwC2AKeAz7f6qHn6mqM+UlXtq6rJqpqcmJi4wCFKkma7oNCvqter6lxV/Rz4ErC1vTUNbBhquh442errR9QlSSvogkK/zdGf9xHg/JU9B4GdSS5Pch2DE7aHq+oUcCbJtnbVzp3A4xcxbknSBVg9X4MkjwK3AFcnmQY+A9ySZAuDKZpXgI8DVNXRJAeAF4GzwJ6qOtdWdTeDK4HWAE+2hyRpBc0b+lX10RHlL8/Rfi+wd0R9CrhhUaOTJC0p78iVpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdmTf0kzyU5HSSF4ZqVyV5KslL7fnKoffuS3IiyfEktw3Vb0pypL33QJIs/e5IkuaykCP9h4Hts2r3AoeqahNwqL0myWZgJ3B96/PFJKtanweB3cCm9pi9TknSMps39KvqO8CPZ5V3APvb8n7g9qH6Y1X1ZlW9DJwAtiZZC1xRVU9XVQGPDPWRJK2QC53Tv7aqTgG052tafR3w2lC76VZb15Zn10dKsjvJVJKpmZmZCxyiJGm2pT6RO2qevuaoj1RV+6pqsqomJyYmlmxwktS7Cw3919uUDe35dKtPAxuG2q0HTrb6+hF1SdIKutDQPwjsasu7gMeH6juTXJ7kOgYnbA+3KaAzSba1q3buHOojSVohq+drkORR4Bbg6iTTwGeA+4EDSe4CXgXuAKiqo0kOAC8CZ4E9VXWurepuBlcCrQGebA9J0gqaN/Sr6qPv8Nat79B+L7B3RH0KuGFRo5MkLSnvyJWkjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR25qNBP8kqSI0meTzLValcleSrJS+35yqH29yU5keR4ktsudvCSpMVZiiP9f1FVW6pqsr2+FzhUVZuAQ+01STYDO4Hrge3AF5OsWoLtS5IWaDmmd3YA+9vyfuD2ofpjVfVmVb0MnAC2LsP2JUnv4GJDv4A/S/Jskt2tdm1VnQJoz9e0+jrgtaG+0632Nkl2J5lKMjUzM3ORQ5Qknbf6IvvfXFUnk1wDPJXkB3O0zYhajWpYVfuAfQCTk5Mj20iSFu+ijvSr6mR7Pg38CYPpmteTrAVoz6db82lgw1D39cDJi9m+JGlxLjj0k7w3ya+cXwb+JfACcBDY1ZrtAh5vyweBnUkuT3IdsAk4fKHblyQt3sVM71wL/EmS8+v546r60yR/ARxIchfwKnAHQFUdTXIAeBE4C+ypqnMXNXpJ0qJccOhX1Q+BXx9R/xvg1nfosxfYe6HblCRdHO/IlaSOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SerIiod+ku1Jjic5keTeld6+JPVsRUM/ySrgD4HfBDYDH02yeSXHIEk9W+kj/a3Aiar6YVX9LfAYsGOFxyBJ3Vrp0F8HvDb0errVJEkrYPUKby8javW2RsluYHd7+UaS48s6quV3NfDX4x7EJcLP4q38PN7Kz6PJ71/0Z/H3RxVXOvSngQ1Dr9cDJ2c3qqp9wL6VGtRySzJVVZPjHselwM/irfw83srP4+8s12ex0tM7fwFsSnJdkr8H7AQOrvAYJKlbK3qkX1Vnk3wC+B/AKuChqjq6kmOQpJ6t9PQOVfUE8MRKb3fM3jVTVUvAz+Kt/Dzeys/j7yzLZ5Gqt51HlSS9S/lnGCSpI4b+MkmyIcm3khxLcjTJPeMe06Ugyaokf5nkG+Mey7gl+dUkX03yg/bv5J+Oe0zjkuTft5+TF5I8muQ94x7TSkryUJLTSV4Yql2V5KkkL7XnK5diW4b+8jkLfKqqfg3YBuzxT04AcA9wbNyDuET8J+BPq+ofA79Op59LknXAvwMmq+oGBhd57BzvqFbcw8D2WbV7gUNVtQk41F5fNEN/mVTVqap6ri2fYfAD3fXdx0nWA78F/NG4xzJuSa4A/jnwZYCq+tuq+sl4RzVWq4E1SVYDv8yI+3fezarqO8CPZ5V3APvb8n7g9qXYlqG/ApJsBG4EnhnvSMbuPwL/Afj5uAdyCfgHwAzwX9p01x8lee+4BzUOVfUj4HPAq8Ap4P9U1Z+Nd1SXhGur6hQMDiKBa5ZipYb+MkvyPuBrwCer6qfjHs+4JPlt4HRVPTvusVwiVgP/BHiwqm4E/i9L9Ov7L5o2V70DuA54P/DeJP96vKN69zL0l1GSyxgE/leq6uvjHs+Y3Qz8qySvMPjrqr+R5L+Nd0hjNQ1MV9X53/6+yuBLoEcfAl6uqpmq+hnwdeCfjXlMl4LXk6wFaM+nl2Klhv4ySRIG87XHquoL4x7PuFXVfVW1vqo2MjhJ9+dV1e3RXFX9L+C1JP+olW4FXhzjkMbpVWBbkl9uPze30ulJ7VkOArva8i7g8aVY6YrfkduRm4GPAUeSPN9qn253JEsA/xb4Svs7VD8E/s2YxzMWVfVMkq8CzzG46u0v6ezO3CSPArcAVyeZBj4D3A8cSHIXgy/GO5ZkW96RK0n9cHpHkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1JH/BwkaH0/JSr3iAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "from matplotlib import pyplot as plt\n", + "\n", + "plt.hist(t[e==1])" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# x, t, e = datasets.load_dataset('SUPPORT')\n", + "# print(x.shape, t.shape, e.shape)\n", + "# x = np.random.random((9105,1,100,100))\n", + "\n", + "# times = np.quantile(t[e==1], [0.25, 0.5, 0.75]).tolist()\n", + "\n", + "# cv_folds = 5\n", + "# folds = list(range(cv_folds))*10000\n", + "# folds = np.array(folds[:len(x)])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1, 10]\n" + ] + }, + { + "ename": "TypeError", + "evalue": "'>' not supported between instances of 'list' and 'int'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mmarks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Average of mark1:\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mavg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmarks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m\u001b[0m in \u001b[0;36mavg\u001b[0;34m(marks)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mavg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmarks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmarks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmarks\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmarks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: '>' not supported between instances of 'list' and 'int'" + ] + } + ], + "source": [ + "def avg(marks):\n", + " print (np.min(marks))\n", + " assert np.min(marks) > 0\n", + " return marks\n", + "\n", + "\n", + "marks = [[1, 10], [10, 20, 5]]\n", + "\n", + "print(\"Average of mark1:\",avg(marks))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "60000" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x.shape[0]\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "(9105, 44) (9105,) (9105,)\n" + "(60000, 784) (60000,) (60000,)\n", + "On Fold: 0\n", + "(50000, 784)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 4%|▎ | 355/10000 [00:02<01:21, 118.84it/s]\n", + " 0%| | 0/100 [00:00