From 822f369004ec46c36ea7ab27de92b545e6fdbd05 Mon Sep 17 00:00:00 2001 From: Noah Kasmanoff Date: Wed, 19 Jun 2024 13:20:59 +0000 Subject: [PATCH] add date --- notebooks/embedding_visualizer.ipynb | 430 +++++++++++++++------------ 1 file changed, 238 insertions(+), 192 deletions(-) diff --git a/notebooks/embedding_visualizer.ipynb b/notebooks/embedding_visualizer.ipynb index cd4828d..abc2512 100644 --- a/notebooks/embedding_visualizer.ipynb +++ b/notebooks/embedding_visualizer.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -40,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -78,7 +78,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -94,7 +94,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -112,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -121,7 +121,7 @@ "(Timestamp('2010-05-13 00:00:00'), Timestamp('2010-05-13 00:24:00'))" ] }, - "execution_count": 46, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -132,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -141,7 +141,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -150,7 +150,7 @@ "(Timestamp('2010-11-01 00:00:00'), Timestamp('2010-11-01 00:24:00'))" ] }, - "execution_count": 53, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -161,31 +161,22 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/opt/conda/envs/sdofm/lib/python3.10/site-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 16 worker processes in total. Our suggested max number of worker in current system is 8, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", - " warnings.warn(_create_warning_msg(\n" - ] - } - ], + "outputs": [], "source": [ "# TODO: confirm, masking takes place in the model right? If yes, this is fine to take as the input" ] }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "cls_embeddings = []\n", "names = []\n", - "for i in range(4):\n", + "for i in range(64):\n", " batch = val_dataset[i]\n", " name = val_dataset.aligndata.iloc[i].name\n", " batch = torch.tensor(batch).unsqueeze(0) \n", @@ -199,45 +190,20 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Timestamp('2010-11-01 00:36:00')" + "(64, 2)" ] }, - "execution_count": 61, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], - "source": [ - "name" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "metadata": {}, - "outputs": [ - { - "ename": "ValueError", - "evalue": "perplexity must be less than n_samples", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[60], line 8\u001b[0m\n\u001b[1;32m 5\u001b[0m tsne \u001b[38;5;241m=\u001b[39m TSNE(n_components\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m, random_state\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 7\u001b[0m cls_embeddings_np \u001b[38;5;241m=\u001b[39m cls_embeddings\u001b[38;5;241m.\u001b[39mnumpy()\n\u001b[0;32m----> 8\u001b[0m cls_embeddings_tsne \u001b[38;5;241m=\u001b[39m \u001b[43mtsne\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_transform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcls_embeddings_np\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 10\u001b[0m cls_embeddings_tsne\u001b[38;5;241m.\u001b[39mshape\n", - "File \u001b[0;32m/opt/conda/envs/sdofm/lib/python3.10/site-packages/sklearn/utils/_set_output.py:295\u001b[0m, in \u001b[0;36m_wrap_method_output..wrapped\u001b[0;34m(self, X, *args, **kwargs)\u001b[0m\n\u001b[1;32m 293\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(f)\n\u001b[1;32m 294\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapped\u001b[39m(\u001b[38;5;28mself\u001b[39m, X, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 295\u001b[0m data_to_wrap \u001b[38;5;241m=\u001b[39m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 296\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data_to_wrap, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 297\u001b[0m \u001b[38;5;66;03m# only wrap the first output for cross decomposition\u001b[39;00m\n\u001b[1;32m 298\u001b[0m return_tuple \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 299\u001b[0m _wrap_data_with_container(method, data_to_wrap[\u001b[38;5;241m0\u001b[39m], X, \u001b[38;5;28mself\u001b[39m),\n\u001b[1;32m 300\u001b[0m \u001b[38;5;241m*\u001b[39mdata_to_wrap[\u001b[38;5;241m1\u001b[39m:],\n\u001b[1;32m 301\u001b[0m )\n", - "File \u001b[0;32m/opt/conda/envs/sdofm/lib/python3.10/site-packages/sklearn/base.py:1474\u001b[0m, in \u001b[0;36m_fit_context..decorator..wrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1467\u001b[0m estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[1;32m 1469\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m 1470\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m 1471\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 1472\u001b[0m )\n\u001b[1;32m 1473\u001b[0m ):\n\u001b[0;32m-> 1474\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/opt/conda/envs/sdofm/lib/python3.10/site-packages/sklearn/manifold/_t_sne.py:1135\u001b[0m, in \u001b[0;36mTSNE.fit_transform\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 1110\u001b[0m \u001b[38;5;129m@_fit_context\u001b[39m(\n\u001b[1;32m 1111\u001b[0m \u001b[38;5;66;03m# TSNE.metric is not validated yet\u001b[39;00m\n\u001b[1;32m 1112\u001b[0m prefer_skip_nested_validation\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 1113\u001b[0m )\n\u001b[1;32m 1114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfit_transform\u001b[39m(\u001b[38;5;28mself\u001b[39m, X, y\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 1115\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Fit X into an embedded space and return that transformed output.\u001b[39;00m\n\u001b[1;32m 1116\u001b[0m \n\u001b[1;32m 1117\u001b[0m \u001b[38;5;124;03m Parameters\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1133\u001b[0m \u001b[38;5;124;03m Embedding of the training data in low-dimensional space.\u001b[39;00m\n\u001b[1;32m 1134\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1135\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_check_params_vs_input\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1136\u001b[0m embedding \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fit(X)\n\u001b[1;32m 1137\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membedding_ \u001b[38;5;241m=\u001b[39m embedding\n", - "File \u001b[0;32m/opt/conda/envs/sdofm/lib/python3.10/site-packages/sklearn/manifold/_t_sne.py:846\u001b[0m, in \u001b[0;36mTSNE._check_params_vs_input\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 844\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_check_params_vs_input\u001b[39m(\u001b[38;5;28mself\u001b[39m, X):\n\u001b[1;32m 845\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mperplexity \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m X\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]:\n\u001b[0;32m--> 846\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mperplexity must be less than n_samples\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "\u001b[0;31mValueError\u001b[0m: perplexity must be less than n_samples" - ] - } - ], "source": [ "# run TSNE on the cls embeddings\n", "from sklearn.manifold import TSNE\n", @@ -253,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -293,162 +259,236 @@ }, "data": [ { - "hovertemplate": "x=%{x}
y=%{y}", + "hovertemplate": "TSNE 1=%{x}
TSNE 2=%{y}
text=%{text}", "legendgroup": "", "marker": { "color": "#636efa", "symbol": "circle" }, - "mode": "markers", + "mode": "markers+text", "name": "", "orientation": "v", "showlegend": false, + "text": [ + "2010-11-01T00:00:00", + "2010-11-01T00:12:00", + "2010-11-01T00:24:00", + "2010-11-01T00:36:00", + "2010-11-01T00:48:00", + "2010-11-01T01:00:00", + "2010-11-01T01:12:00", + "2010-11-01T01:24:00", + "2010-11-01T01:36:00", + "2010-11-01T01:48:00", + "2010-11-01T02:00:00", + "2010-11-01T02:12:00", + "2010-11-01T02:24:00", + "2010-11-01T02:36:00", + "2010-11-01T02:48:00", + "2010-11-01T03:00:00", + "2010-11-01T03:12:00", + "2010-11-01T03:24:00", + "2010-11-01T03:36:00", + "2010-11-01T03:48:00", + "2010-11-01T04:00:00", + "2010-11-01T04:12:00", + "2010-11-01T04:24:00", + "2010-11-01T04:36:00", + "2010-11-01T04:48:00", + "2010-11-01T05:00:00", + "2010-11-01T05:12:00", + "2010-11-01T05:24:00", + "2010-11-01T05:36:00", + "2010-11-01T05:48:00", + "2010-11-01T06:00:00", + "2010-11-01T06:12:00", + "2010-11-01T06:24:00", + "2010-11-01T06:36:00", + "2010-11-01T06:48:00", + "2010-11-01T07:00:00", + "2010-11-01T07:12:00", + "2010-11-01T07:24:00", + "2010-11-01T07:36:00", + "2010-11-01T07:48:00", + "2010-11-01T08:00:00", + "2010-11-01T08:12:00", + "2010-11-01T08:24:00", + "2010-11-01T08:36:00", + "2010-11-01T08:48:00", + "2010-11-01T09:00:00", + "2010-11-01T09:12:00", + "2010-11-01T09:24:00", + "2010-11-01T09:36:00", + "2010-11-01T09:48:00", + "2010-11-01T10:00:00", + "2010-11-01T10:12:00", + "2010-11-01T10:24:00", + "2010-11-01T10:36:00", + "2010-11-01T10:48:00", + "2010-11-01T11:00:00", + "2010-11-01T11:12:00", + "2010-11-01T11:24:00", + "2010-11-01T11:36:00", + "2010-11-01T11:48:00", + "2010-11-01T12:00:00", + "2010-11-01T12:12:00", + "2010-11-01T12:24:00", + "2010-11-01T12:36:00" + ], + "textfont": { + "size": 8 + }, + "textposition": "top center", "type": "scatter", "x": [ - -1.4775547981262207, - -1.017698049545288, - -2.3421990871429443, - -0.9686383008956909, - -1.7953349351882935, - -2.230997323989868, - -0.9831541776657104, - -0.7328781485557556, - -1.3800632953643799, - 0.07800032943487167, - -1.5571708679199219, - -2.487609386444092, - -2.261706829071045, - -0.9377910494804382, - 0.7416231036186218, - -1.6795812845230103, - -1.3866422176361084, - -0.6841626167297363, - -2.334804058074951, - 0.6508134007453918, - -0.25103533267974854, - -1.0810067653656006, - -2.5791263580322266, - -0.7387914061546326, - 0.13009090721607208, - 0.9312425851821899, - -1.8240466117858887, - -2.1183955669403076, - -0.7637702226638794, - -1.4372256994247437, - -2.359626531600952, - 0.2970408499240875, - -2.491058588027954, - -1.3080605268478394, - 0.5572676658630371, - -1.7019890546798706, - -0.7962026596069336, - -0.8217019438743591, - -1.7293106317520142, - 0.9912880063056946, - -0.679885983467102, - 0.7837218642234802, - -2.402461528778076, - -1.6201286315917969, - -2.635576009750366, - -1.7331405878067017, - -0.3813764154911041, - -1.7368831634521484, - -1.6549166440963745, - -2.138592004776001, - 0.16142168641090393, - -0.18802887201309204, - -2.586073875427246, - -1.2996861934661865, - -0.9998782277107239, - -2.9353811740875244, - -1.5950214862823486, - -1.8743672370910645, - -2.4981508255004883, - -2.320810317993164, - 0.023177141323685646, - -1.3722947835922241, - -1.5299345254898071, - -0.12524938583374023, - 0.1468801647424698 + 2.5566623210906982, + 2.738436222076416, + 2.568350315093994, + 3.1591901779174805, + 3.270948648452759, + 2.8869283199310303, + 3.497849702835083, + 3.2450766563415527, + 3.484912395477295, + 3.8625094890594482, + 4.049447059631348, + 4.323691368103027, + 4.315840244293213, + 4.724743843078613, + 4.627059459686279, + 4.677896022796631, + 4.822544574737549, + 4.818457126617432, + 4.611992359161377, + 4.54832649230957, + 4.627342700958252, + 4.8484578132629395, + 5.244244575500488, + 5.632866382598877, + 5.781790733337402, + 6.125435829162598, + 6.56826114654541, + 6.892115592956543, + 7.076461315155029, + 7.033135890960693, + 7.3944621086120605, + 7.671503067016602, + 8.49928092956543, + 8.518630981445312, + 8.820515632629395, + 8.315404891967773, + 7.945209503173828, + 8.115118980407715, + 8.22454833984375, + 8.426251411437988, + 8.542154312133789, + 8.624923706054688, + 8.422052383422852, + 8.707727432250977, + 8.279617309570312, + 8.436790466308594, + 8.623080253601074, + 7.798712253570557, + 7.782196998596191, + 8.037912368774414, + 7.922629356384277, + 7.9688191413879395, + 7.213320255279541, + 7.004072189331055, + 6.367611885070801, + 6.437508583068848, + 6.027431011199951, + 5.864524841308594, + 5.728527545928955, + 5.332927703857422, + 5.372358798980713, + 5.271961212158203, + 5.358461856842041, + 5.462548732757568 ], "xaxis": "x", "y": [ - 4.6039042472839355, - 3.3738842010498047, - 6.149228096008301, - -0.5258415937423706, - 2.056642770767212, - 8.729084968566895, - 6.3997416496276855, - 0.3447766602039337, - 1.2026575803756714, - 2.714287281036377, - 2.4941611289978027, - 7.620736598968506, - 5.373752117156982, - 5.296590328216553, - 1.398061990737915, - 6.553701400756836, - 7.706721782684326, - 4.5697021484375, - 8.997517585754395, - 4.656033039093018, - 0.9822078943252563, - 2.5054409503936768, - 8.946307182312012, - 0.40767043828964233, - 1.265901803970337, - 2.361562967300415, - 6.890122413635254, - 6.36600923538208, - 4.107271671295166, - 0.20807480812072754, - 9.166582107543945, - 4.849384307861328, - 4.631340980529785, - 5.935306549072266, - 4.715306282043457, - 5.031026363372803, - 0.9169295430183411, - 3.456557512283325, - 3.681023359298706, - 2.1541361808776855, - 2.4078845977783203, - 1.9001367092132568, - 5.974546432495117, - 1.1110247373580933, - 7.897974967956543, - 0.6024311780929565, - 2.286405086517334, - 6.542756080627441, - 0.14914807677268982, - 8.140671730041504, - 0.6779770851135254, - -0.7267722487449646, - 7.456902027130127, - 6.278473854064941, - 2.9441044330596924, - 6.635173797607422, - 4.596971035003662, - 4.0217814445495605, - 6.33397102355957, - 9.130873680114746, - 4.633718490600586, - 5.098238945007324, - 0.6009475588798523, - 0.7665007710456848, - 1.6283544301986694 + -1.8961414098739624, + -1.9025202989578247, + -1.7482821941375732, + -1.575560212135315, + -1.248615026473999, + -0.6440263986587524, + -0.6478663682937622, + -0.5125745534896851, + -0.055819179862737656, + -0.18983420729637146, + 0.25126489996910095, + -0.02481245808303356, + 0.6518766283988953, + 1.2845197916030884, + 1.3000599145889282, + 1.8265562057495117, + 1.9414236545562744, + 2.6679351329803467, + 3.1173977851867676, + 3.3452675342559814, + 3.9025354385375977, + 4.198707103729248, + 4.09030818939209, + 4.234315395355225, + 4.199445724487305, + 4.26954984664917, + 4.247828960418701, + 3.8589251041412354, + 3.5774152278900146, + 3.267285108566284, + 3.0154905319213867, + 2.9258055686950684, + 2.8995590209960938, + 2.480369806289673, + 2.5585694313049316, + 2.0041935443878174, + 1.5368666648864746, + 1.3818058967590332, + 1.0307940244674683, + 0.6594323515892029, + 0.37725165486335754, + -0.6920632123947144, + -0.829949140548706, + -1.4270269870758057, + -1.368971586227417, + -1.932491421699524, + -2.2837164402008057, + -2.09818434715271, + -2.343635320663452, + -2.831899404525757, + -3.153186798095703, + -3.315528154373169, + -3.19529128074646, + -3.028510570526123, + -2.948838710784912, + -3.6706228256225586, + -3.627720355987549, + -3.6556167602539062, + -3.2960453033447266, + -3.423231363296509, + -2.9069814682006836, + -2.802427291870117, + -2.4083144664764404, + -2.109381914138794 ], "yaxis": "y" } ], "layout": { + "autosize": false, + "height": 800, "legend": { "tracegroupgap": 0 }, "margin": { - "t": 60 + "b": 0, + "l": 0, + "r": 0, + "t": 0 }, + "showlegend": false, "template": { "data": { "bar": [ @@ -1265,6 +1305,10 @@ } } }, + "title": { + "text": "TSNE of CLS embeddings" + }, + "width": 800, "xaxis": { "anchor": "y", "domain": [ @@ -1272,7 +1316,7 @@ 1 ], "title": { - "text": "x" + "text": "TSNE 1" } }, "yaxis": { @@ -1282,15 +1326,15 @@ 1 ], "title": { - "text": "y" + "text": "TSNE 2" } } } }, "text/html": [ - "