diff --git a/.gitignore b/.gitignore index 2e60f10959..78eaf9cdee 100644 --- a/.gitignore +++ b/.gitignore @@ -155,6 +155,14 @@ ml-20m/ *.model *.mml nohup.out +*.svg +*.html +*.js +*.css +*.ttf +*.woff +*.woff2 +*.eot ##### kdd 2020 tutorial data folder examples/07_tutorials/KDD2020-tutorial/data_folder/ @@ -164,4 +172,4 @@ examples/07_tutorials/KDD2020-tutorial/data_folder/ *.sh tests/**/resources/ -reports/ \ No newline at end of file +reports/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3b76affbe9..8d4295c3df 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -23,7 +23,7 @@ Here are the basic steps to get started with your first contribution. Please rea 5. Install development requirements. `pip install -r dev-requirements.txt` 6. Create a test that replicates the issue. 7. Make code changes. -8. Ensure unit tests pass and code style / formatting is consistent (see [wiki](https://github.com/Microsoft/Recommenders/wiki/Coding-Guidelines#python-and-docstrings-style) for more details). +8. Ensure that unit tests pass and code style / formatting is consistent (see [wiki](https://github.com/Microsoft/Recommenders/wiki/Coding-Guidelines#python-and-docstrings-style) for more details). In particular, make sure that there is a docstring for every function you add and that it conforms to the Google style. 9. Create a pull request against **staging** branch. Once the features included in a [milestone](https://github.com/microsoft/recommenders/milestones) are completed, we will merge staging into main. See the wiki for more detail about our [merge strategy](https://github.com/microsoft/recommenders/wiki/Strategy-to-merge-the-code-to-main-branch). 
diff --git a/docs/README.md b/docs/README.md index 7e45985c30..707978e0c2 100644 --- a/docs/README.md +++ b/docs/README.md @@ -13,4 +13,4 @@ To build the documentation as HTML: cd docs make html -To contribute to this repository, please follow these [guidelines](https://github.com/Microsoft/Recommenders/wiki/Coding-Guidelines). \ No newline at end of file +To contribute to this repository, please follow these [guidelines](https://github.com/Microsoft/Recommenders/wiki/Coding-Guidelines). See also the [Sphinx documentation](https://sublime-and-sphinx-guide.readthedocs.io/en/latest/index.html) for the syntax of docstrings. \ No newline at end of file diff --git a/docs/source/dataset.rst b/docs/source/dataset.rst index 0594ffcb23..d35a7d12bb 100644 --- a/docs/source/dataset.rst +++ b/docs/source/dataset.rst @@ -1,49 +1,66 @@ .. _dataset: Dataset module -************************** +############## -Recommendation datasets -=============================== +Recommendation datasets and related utilities -.. automodule:: reco_utils.dataset.movielens - :members: +Recommendation datasets +*********************** -.. automodule:: reco_utils.dataset.criteo - :members: +Amazon Reviews +============== .. automodule:: reco_utils.dataset.amazon_reviews :members: +Azure COVID-19 +============== + .. automodule:: reco_utils.dataset.covid_utils :members: +Criteo +====== + +.. automodule:: reco_utils.dataset.criteo + :members: + +MIND +==== + .. automodule:: reco_utils.dataset.mind :members: +MovieLens +========= + +.. automodule:: reco_utils.dataset.movielens + :members: + Download utilities -=============================== +****************** .. automodule:: reco_utils.dataset.download_utils :members: -Cosmos CLI -=============================== +Cosmos CLI utilities +********************* .. automodule:: reco_utils.dataset.cosmos_cli :members: -Pandas dataframe utils -=============================== +Pandas dataframe utilities +*************************** .. 
automodule:: reco_utils.dataset.pandas_df_utils :members: Splitter utilities -=============================== +****************** .. automodule:: reco_utils.dataset.python_splitters :members: @@ -56,14 +73,14 @@ Splitter utilities Sparse utilities -=============================== +**************** .. automodule:: reco_utils.dataset.sparse :members: Knowledge graph utilities -=============================== +************************* .. automodule:: reco_utils.dataset.wikidata :members: diff --git a/examples/01_prepare_data/mind_utils.ipynb b/examples/01_prepare_data/mind_utils.ipynb index e56321e9cc..1a7ba75701 100644 --- a/examples/01_prepare_data/mind_utils.ipynb +++ b/examples/01_prepare_data/mind_utils.ipynb @@ -55,7 +55,7 @@ "from tempfile import TemporaryDirectory\n", "from reco_utils.dataset.mind import (download_mind,\n", " extract_mind,\n", - " download_and_extract_globe,\n", + " download_and_extract_glove,\n", " load_glove_matrix,\n", " word_tokenize\n", " )\n", @@ -326,7 +326,7 @@ } ], "source": [ - "glove_path = download_and_extract_globe(data_path)" + "glove_path = download_and_extract_glove(data_path)" ] }, { diff --git a/examples/02_model_hybrid/fm_deep_dive.ipynb b/examples/02_model_hybrid/fm_deep_dive.ipynb index 36589a5d5e..eb1754bed2 100644 --- a/examples/02_model_hybrid/fm_deep_dive.ipynb +++ b/examples/02_model_hybrid/fm_deep_dive.ipynb @@ -231,7 +231,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "System version: 3.6.8 |Anaconda, Inc.| (default, Dec 30 2018, 01:22:34) \n", + "System version: 3.6.13 |Anaconda, Inc.| (default, Feb 23 2021, 21:15:04) \n", "[GCC 7.3.0]\n", "Xlearn version: 0.4.0\n" ] @@ -239,7 +239,6 @@ ], "source": [ "import sys\n", - "import os\n", "import papermill as pm\n", "import scrapbook as sb\n", @@ -254,11 +253,7 @@ "\n", "from reco_utils.common.constants import SEED\n", "from reco_utils.common.timer import Timer\n", - "from reco_utils.recommender.deeprec.deeprec_utils import (\n", - " 
download_deeprec_resources, prepare_hparams\n", - ")\n", - "from reco_utils.recommender.deeprec.models.xDeepFM import XDeepFMModel\n", - "from reco_utils.recommender.deeprec.io.iterator import FFMTextIterator\n", + "from reco_utils.dataset.download_utils import maybe_download, unzip_file\n", "from reco_utils.tuning.parameter_sweep import generate_param_grid\n", "from reco_utils.dataset.pandas_df_utils import LibffmConverter\n", "\n", @@ -414,7 +409,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": { "tags": [ "parameters" @@ -422,37 +417,39 @@ }, "outputs": [], "source": [ - "# Parameters\n", - "YAML_FILE_NAME = \"xDeepFM.yaml\"\n", - "TRAIN_FILE_NAME = \"cretio_tiny_train\"\n", - "VALID_FILE_NAME = \"cretio_tiny_valid\"\n", - "TEST_FILE_NAME = \"cretio_tiny_test\"\n", - "MODEL_FILE_NAME = \"model.out\"\n", - "OUTPUT_FILE_NAME = \"output.txt\"\n", - "\n", + "# Model parameters\n", "LEARNING_RATE = 0.2\n", "LAMBDA = 0.002\n", + "EPOCH = 10\n", + "OPT_METHOD = \"sgd\" # options are \"sgd\", \"adagrad\" and \"ftrl\"\n", + "\n", "# The metrics for binary classification options are \"acc\", \"prec\", \"f1\" and \"auc\"\n", "# for regression, options are \"rmse\", \"mae\", \"mape\"\n", - "METRIC = \"auc\" \n", - "EPOCH = 10\n", - "OPT_METHOD = \"sgd\" # options are \"sgd\", \"adagrad\" and \"ftrl\"" + "METRIC = \"auc\" \n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 10.3k/10.3k [00:01<00:00, 8.67kKB/s]\n" + "100%|██████████| 10.3k/10.3k [00:00<00:00, 55.9kKB/s]\n" ] } ], "source": [ + "# Paths\n", + "YAML_FILE_NAME = \"xDeepFM.yaml\"\n", + "TRAIN_FILE_NAME = \"cretio_tiny_train\"\n", + "VALID_FILE_NAME = \"cretio_tiny_valid\"\n", + "TEST_FILE_NAME = \"cretio_tiny_test\"\n", + "MODEL_FILE_NAME = \"model.out\"\n", + "OUTPUT_FILE_NAME = \"output.txt\"\n", + "\n", "tmpdir = TemporaryDirectory()\n", 
"\n", "data_path = tmpdir.name\n", @@ -463,8 +460,9 @@ "model_file = os.path.join(data_path, MODEL_FILE_NAME)\n", "output_file = os.path.join(data_path, OUTPUT_FILE_NAME)\n", "\n", - "if not os.path.exists(yaml_file):\n", - " download_deeprec_resources(r'https://recodatasets.z20.web.core.windows.net/deeprec/', data_path, 'xdeepfmresources.zip')" + "assets_url = \"https://recodatasets.z20.web.core.windows.net/deeprec/xdeepfmresources.zip\"\n", + "assets_file = maybe_download(assets_url, work_directory=data_path)\n", + "unzip_file(assets_file, data_path)" ] }, { @@ -483,9 +481,17 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training time: 14.4424\n" + ] + } + ], "source": [ "# Training task\n", "ffm_model = xl.create_ffm() # Use field-aware factorization machine (ffm)\n", @@ -511,7 +517,23 @@ "# The trained model will be stored in model.out\n", "with Timer() as time_train:\n", " ffm_model.fit(param, model_file)\n", - "\n", + "print(f\"Training time: {time_train}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prediction time: 0.6435\n" + ] + } + ], + "source": [ "# Prediction task\n", "ffm_model.setTest(test_file) # Set the path of test dataset\n", "ffm_model.setSigmoid() # Convert output to 0-1\n", @@ -519,7 +541,8 @@ "# Start to predict\n", "# The output result will be stored in output.txt\n", "with Timer() as time_predict:\n", - " ffm_model.predict(model_file, output_file)" + " ffm_model.predict(model_file, output_file)\n", + "print(f\"Prediction time: {time_predict}\")" ] }, { @@ -531,7 +554,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -549,16 +572,16 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 
14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "0.7498803439718372" + "0.7485411618010794" ] }, - "execution_count": 8, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -586,23 +609,6 @@ "sb.glue('auc_score', auc_score)" ] }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training takes 10.77s and predicting takes 0.93s.\n" - ] - } - ], - "source": [ - "print('Training takes {0:.2f}s and predicting takes {1:.2f}s.'.format(time_train.interval, time_predict.interval))" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -641,7 +647,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -655,7 +661,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -686,14 +692,14 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Tuning by grid search takes 4.6 min\n" + "Tuning by grid search takes 4.2 min\n" ] } ], @@ -703,7 +709,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -740,21 +746,21 @@ " \n", " \n", " \n", - " 0.0001\n", - " 0.5482\n", + " 0.0001\n", + " 0.5481\n", " 0.6122\n", " 0.7210\n", " \n", " \n", - " 0.0010\n", - " 0.5456\n", - " 0.6101\n", - " 0.7246\n", + " 0.0010\n", + " 0.5454\n", + " 0.6103\n", + " 0.7245\n", " \n", " \n", - " 0.0100\n", - " 0.5406\n", - " 0.6147\n", + " 0.0100\n", + " 0.5405\n", + " 0.6150\n", " 0.7238\n", " \n", " \n", @@ -764,12 +770,12 @@ "text/plain": [ "Lambda 0.001 0.010 0.100\n", "LR \n", - "0.0001 0.5482 0.6122 0.7210\n", - "0.0010 0.5456 0.6101 0.7246\n", - "0.0100 0.5406 0.6147 0.7238" + "0.0001 0.5481 0.6122 0.7210\n", + "0.0010 0.5454 0.6103 0.7245\n", + 
"0.0100 0.5405 0.6150 0.7238" ] }, - "execution_count": 14, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -788,30 +794,988 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 21, "metadata": {}, "outputs": [ { "data": { + "application/javascript": [ + "/* Put everything inside the global mpl namespace */\n", + "/* global mpl */\n", + "window.mpl = {};\n", + "\n", + "mpl.get_websocket_type = function () {\n", + " if (typeof WebSocket !== 'undefined') {\n", + " return WebSocket;\n", + " } else if (typeof MozWebSocket !== 'undefined') {\n", + " return MozWebSocket;\n", + " } else {\n", + " alert(\n", + " 'Your browser does not have WebSocket support. ' +\n", + " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", + " 'Firefox 4 and 5 are also supported but you ' +\n", + " 'have to enable WebSockets in about:config.'\n", + " );\n", + " }\n", + "};\n", + "\n", + "mpl.figure = function (figure_id, websocket, ondownload, parent_element) {\n", + " this.id = figure_id;\n", + "\n", + " this.ws = websocket;\n", + "\n", + " this.supports_binary = this.ws.binaryType !== undefined;\n", + "\n", + " if (!this.supports_binary) {\n", + " var warnings = document.getElementById('mpl-warnings');\n", + " if (warnings) {\n", + " warnings.style.display = 'block';\n", + " warnings.textContent =\n", + " 'This browser does not support binary websocket messages. 
' +\n", + " 'Performance may be slow.';\n", + " }\n", + " }\n", + "\n", + " this.imageObj = new Image();\n", + "\n", + " this.context = undefined;\n", + " this.message = undefined;\n", + " this.canvas = undefined;\n", + " this.rubberband_canvas = undefined;\n", + " this.rubberband_context = undefined;\n", + " this.format_dropdown = undefined;\n", + "\n", + " this.image_mode = 'full';\n", + "\n", + " this.root = document.createElement('div');\n", + " this.root.setAttribute('style', 'display: inline-block');\n", + " this._root_extra_style(this.root);\n", + "\n", + " parent_element.appendChild(this.root);\n", + "\n", + " this._init_header(this);\n", + " this._init_canvas(this);\n", + " this._init_toolbar(this);\n", + "\n", + " var fig = this;\n", + "\n", + " this.waiting = false;\n", + "\n", + " this.ws.onopen = function () {\n", + " fig.send_message('supports_binary', { value: fig.supports_binary });\n", + " fig.send_message('send_image_mode', {});\n", + " if (fig.ratio !== 1) {\n", + " fig.send_message('set_dpi_ratio', { dpi_ratio: fig.ratio });\n", + " }\n", + " fig.send_message('refresh', {});\n", + " };\n", + "\n", + " this.imageObj.onload = function () {\n", + " if (fig.image_mode === 'full') {\n", + " // Full images could contain transparency (where diff images\n", + " // almost always do), so we need to clear the canvas so that\n", + " // there is no ghosting.\n", + " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", + " }\n", + " fig.context.drawImage(fig.imageObj, 0, 0);\n", + " };\n", + "\n", + " this.imageObj.onunload = function () {\n", + " fig.ws.close();\n", + " };\n", + "\n", + " this.ws.onmessage = this._make_on_message_function(this);\n", + "\n", + " this.ondownload = ondownload;\n", + "};\n", + "\n", + "mpl.figure.prototype._init_header = function () {\n", + " var titlebar = document.createElement('div');\n", + " titlebar.classList =\n", + " 'ui-dialog-titlebar ui-widget-header ui-corner-all ui-helper-clearfix';\n", + " var 
titletext = document.createElement('div');\n", + " titletext.classList = 'ui-dialog-title';\n", + " titletext.setAttribute(\n", + " 'style',\n", + " 'width: 100%; text-align: center; padding: 3px;'\n", + " );\n", + " titlebar.appendChild(titletext);\n", + " this.root.appendChild(titlebar);\n", + " this.header = titletext;\n", + "};\n", + "\n", + "mpl.figure.prototype._canvas_extra_style = function (_canvas_div) {};\n", + "\n", + "mpl.figure.prototype._root_extra_style = function (_canvas_div) {};\n", + "\n", + "mpl.figure.prototype._init_canvas = function () {\n", + " var fig = this;\n", + "\n", + " var canvas_div = (this.canvas_div = document.createElement('div'));\n", + " canvas_div.setAttribute(\n", + " 'style',\n", + " 'border: 1px solid #ddd;' +\n", + " 'box-sizing: content-box;' +\n", + " 'clear: both;' +\n", + " 'min-height: 1px;' +\n", + " 'min-width: 1px;' +\n", + " 'outline: 0;' +\n", + " 'overflow: hidden;' +\n", + " 'position: relative;' +\n", + " 'resize: both;'\n", + " );\n", + "\n", + " function on_keyboard_event_closure(name) {\n", + " return function (event) {\n", + " return fig.key_event(event, name);\n", + " };\n", + " }\n", + "\n", + " canvas_div.addEventListener(\n", + " 'keydown',\n", + " on_keyboard_event_closure('key_press')\n", + " );\n", + " canvas_div.addEventListener(\n", + " 'keyup',\n", + " on_keyboard_event_closure('key_release')\n", + " );\n", + "\n", + " this._canvas_extra_style(canvas_div);\n", + " this.root.appendChild(canvas_div);\n", + "\n", + " var canvas = (this.canvas = document.createElement('canvas'));\n", + " canvas.classList.add('mpl-canvas');\n", + " canvas.setAttribute('style', 'box-sizing: content-box;');\n", + "\n", + " this.context = canvas.getContext('2d');\n", + "\n", + " var backingStore =\n", + " this.context.backingStorePixelRatio ||\n", + " this.context.webkitBackingStorePixelRatio ||\n", + " this.context.mozBackingStorePixelRatio ||\n", + " this.context.msBackingStorePixelRatio ||\n", + " 
this.context.oBackingStorePixelRatio ||\n", + " this.context.backingStorePixelRatio ||\n", + " 1;\n", + "\n", + " this.ratio = (window.devicePixelRatio || 1) / backingStore;\n", + "\n", + " var rubberband_canvas = (this.rubberband_canvas = document.createElement(\n", + " 'canvas'\n", + " ));\n", + " rubberband_canvas.setAttribute(\n", + " 'style',\n", + " 'box-sizing: content-box; position: absolute; left: 0; top: 0; z-index: 1;'\n", + " );\n", + "\n", + " // Apply a ponyfill if ResizeObserver is not implemented by browser.\n", + " if (this.ResizeObserver === undefined) {\n", + " if (window.ResizeObserver !== undefined) {\n", + " this.ResizeObserver = window.ResizeObserver;\n", + " } else {\n", + " var obs = _JSXTOOLS_RESIZE_OBSERVER({});\n", + " this.ResizeObserver = obs.ResizeObserver;\n", + " }\n", + " }\n", + "\n", + " this.resizeObserverInstance = new this.ResizeObserver(function (entries) {\n", + " var nentries = entries.length;\n", + " for (var i = 0; i < nentries; i++) {\n", + " var entry = entries[i];\n", + " var width, height;\n", + " if (entry.contentBoxSize) {\n", + " if (entry.contentBoxSize instanceof Array) {\n", + " // Chrome 84 implements new version of spec.\n", + " width = entry.contentBoxSize[0].inlineSize;\n", + " height = entry.contentBoxSize[0].blockSize;\n", + " } else {\n", + " // Firefox implements old version of spec.\n", + " width = entry.contentBoxSize.inlineSize;\n", + " height = entry.contentBoxSize.blockSize;\n", + " }\n", + " } else {\n", + " // Chrome <84 implements even older version of spec.\n", + " width = entry.contentRect.width;\n", + " height = entry.contentRect.height;\n", + " }\n", + "\n", + " // Keep the size of the canvas and rubber band canvas in sync with\n", + " // the canvas container.\n", + " if (entry.devicePixelContentBoxSize) {\n", + " // Chrome 84 implements new version of spec.\n", + " canvas.setAttribute(\n", + " 'width',\n", + " entry.devicePixelContentBoxSize[0].inlineSize\n", + " );\n", + " 
canvas.setAttribute(\n", + " 'height',\n", + " entry.devicePixelContentBoxSize[0].blockSize\n", + " );\n", + " } else {\n", + " canvas.setAttribute('width', width * fig.ratio);\n", + " canvas.setAttribute('height', height * fig.ratio);\n", + " }\n", + " canvas.setAttribute(\n", + " 'style',\n", + " 'width: ' + width + 'px; height: ' + height + 'px;'\n", + " );\n", + "\n", + " rubberband_canvas.setAttribute('width', width);\n", + " rubberband_canvas.setAttribute('height', height);\n", + "\n", + " // And update the size in Python. We ignore the initial 0/0 size\n", + " // that occurs as the element is placed into the DOM, which should\n", + " // otherwise not happen due to the minimum size styling.\n", + " if (fig.ws.readyState == 1 && width != 0 && height != 0) {\n", + " fig.request_resize(width, height);\n", + " }\n", + " }\n", + " });\n", + " this.resizeObserverInstance.observe(canvas_div);\n", + "\n", + " function on_mouse_event_closure(name) {\n", + " return function (event) {\n", + " return fig.mouse_event(event, name);\n", + " };\n", + " }\n", + "\n", + " rubberband_canvas.addEventListener(\n", + " 'mousedown',\n", + " on_mouse_event_closure('button_press')\n", + " );\n", + " rubberband_canvas.addEventListener(\n", + " 'mouseup',\n", + " on_mouse_event_closure('button_release')\n", + " );\n", + " // Throttle sequential mouse events to 1 every 20ms.\n", + " rubberband_canvas.addEventListener(\n", + " 'mousemove',\n", + " on_mouse_event_closure('motion_notify')\n", + " );\n", + "\n", + " rubberband_canvas.addEventListener(\n", + " 'mouseenter',\n", + " on_mouse_event_closure('figure_enter')\n", + " );\n", + " rubberband_canvas.addEventListener(\n", + " 'mouseleave',\n", + " on_mouse_event_closure('figure_leave')\n", + " );\n", + "\n", + " canvas_div.addEventListener('wheel', function (event) {\n", + " if (event.deltaY < 0) {\n", + " event.step = 1;\n", + " } else {\n", + " event.step = -1;\n", + " }\n", + " on_mouse_event_closure('scroll')(event);\n", + " 
});\n", + "\n", + " canvas_div.appendChild(canvas);\n", + " canvas_div.appendChild(rubberband_canvas);\n", + "\n", + " this.rubberband_context = rubberband_canvas.getContext('2d');\n", + " this.rubberband_context.strokeStyle = '#000000';\n", + "\n", + " this._resize_canvas = function (width, height, forward) {\n", + " if (forward) {\n", + " canvas_div.style.width = width + 'px';\n", + " canvas_div.style.height = height + 'px';\n", + " }\n", + " };\n", + "\n", + " // Disable right mouse context menu.\n", + " this.rubberband_canvas.addEventListener('contextmenu', function (_e) {\n", + " event.preventDefault();\n", + " return false;\n", + " });\n", + "\n", + " function set_focus() {\n", + " canvas.focus();\n", + " canvas_div.focus();\n", + " }\n", + "\n", + " window.setTimeout(set_focus, 100);\n", + "};\n", + "\n", + "mpl.figure.prototype._init_toolbar = function () {\n", + " var fig = this;\n", + "\n", + " var toolbar = document.createElement('div');\n", + " toolbar.classList = 'mpl-toolbar';\n", + " this.root.appendChild(toolbar);\n", + "\n", + " function on_click_closure(name) {\n", + " return function (_event) {\n", + " return fig.toolbar_button_onclick(name);\n", + " };\n", + " }\n", + "\n", + " function on_mouseover_closure(tooltip) {\n", + " return function (event) {\n", + " if (!event.currentTarget.disabled) {\n", + " return fig.toolbar_button_onmouseover(tooltip);\n", + " }\n", + " };\n", + " }\n", + "\n", + " fig.buttons = {};\n", + " var buttonGroup = document.createElement('div');\n", + " buttonGroup.classList = 'mpl-button-group';\n", + " for (var toolbar_ind in mpl.toolbar_items) {\n", + " var name = mpl.toolbar_items[toolbar_ind][0];\n", + " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", + " var image = mpl.toolbar_items[toolbar_ind][2];\n", + " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", + "\n", + " if (!name) {\n", + " /* Instead of a spacer, we start a new button group. 
*/\n", + " if (buttonGroup.hasChildNodes()) {\n", + " toolbar.appendChild(buttonGroup);\n", + " }\n", + " buttonGroup = document.createElement('div');\n", + " buttonGroup.classList = 'mpl-button-group';\n", + " continue;\n", + " }\n", + "\n", + " var button = (fig.buttons[name] = document.createElement('button'));\n", + " button.classList = 'mpl-widget';\n", + " button.setAttribute('role', 'button');\n", + " button.setAttribute('aria-disabled', 'false');\n", + " button.addEventListener('click', on_click_closure(method_name));\n", + " button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n", + "\n", + " var icon_img = document.createElement('img');\n", + " icon_img.src = '_images/' + image + '.png';\n", + " icon_img.srcset = '_images/' + image + '_large.png 2x';\n", + " icon_img.alt = tooltip;\n", + " button.appendChild(icon_img);\n", + "\n", + " buttonGroup.appendChild(button);\n", + " }\n", + "\n", + " if (buttonGroup.hasChildNodes()) {\n", + " toolbar.appendChild(buttonGroup);\n", + " }\n", + "\n", + " var fmt_picker = document.createElement('select');\n", + " fmt_picker.classList = 'mpl-widget';\n", + " toolbar.appendChild(fmt_picker);\n", + " this.format_dropdown = fmt_picker;\n", + "\n", + " for (var ind in mpl.extensions) {\n", + " var fmt = mpl.extensions[ind];\n", + " var option = document.createElement('option');\n", + " option.selected = fmt === mpl.default_extension;\n", + " option.innerHTML = fmt;\n", + " fmt_picker.appendChild(option);\n", + " }\n", + "\n", + " var status_bar = document.createElement('span');\n", + " status_bar.classList = 'mpl-message';\n", + " toolbar.appendChild(status_bar);\n", + " this.message = status_bar;\n", + "};\n", + "\n", + "mpl.figure.prototype.request_resize = function (x_pixels, y_pixels) {\n", + " // Request matplotlib to resize the figure. 
Matplotlib will then trigger a resize in the client,\n", + " // which will in turn request a refresh of the image.\n", + " this.send_message('resize', { width: x_pixels, height: y_pixels });\n", + "};\n", + "\n", + "mpl.figure.prototype.send_message = function (type, properties) {\n", + " properties['type'] = type;\n", + " properties['figure_id'] = this.id;\n", + " this.ws.send(JSON.stringify(properties));\n", + "};\n", + "\n", + "mpl.figure.prototype.send_draw_message = function () {\n", + " if (!this.waiting) {\n", + " this.waiting = true;\n", + " this.ws.send(JSON.stringify({ type: 'draw', figure_id: this.id }));\n", + " }\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_save = function (fig, _msg) {\n", + " var format_dropdown = fig.format_dropdown;\n", + " var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n", + " fig.ondownload(fig, format);\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_resize = function (fig, msg) {\n", + " var size = msg['size'];\n", + " if (size[0] !== fig.canvas.width || size[1] !== fig.canvas.height) {\n", + " fig._resize_canvas(size[0], size[1], msg['forward']);\n", + " fig.send_message('refresh', {});\n", + " }\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_rubberband = function (fig, msg) {\n", + " var x0 = msg['x0'] / fig.ratio;\n", + " var y0 = (fig.canvas.height - msg['y0']) / fig.ratio;\n", + " var x1 = msg['x1'] / fig.ratio;\n", + " var y1 = (fig.canvas.height - msg['y1']) / fig.ratio;\n", + " x0 = Math.floor(x0) + 0.5;\n", + " y0 = Math.floor(y0) + 0.5;\n", + " x1 = Math.floor(x1) + 0.5;\n", + " y1 = Math.floor(y1) + 0.5;\n", + " var min_x = Math.min(x0, x1);\n", + " var min_y = Math.min(y0, y1);\n", + " var width = Math.abs(x1 - x0);\n", + " var height = Math.abs(y1 - y0);\n", + "\n", + " fig.rubberband_context.clearRect(\n", + " 0,\n", + " 0,\n", + " fig.canvas.width / fig.ratio,\n", + " fig.canvas.height / fig.ratio\n", + " );\n", + "\n", + " fig.rubberband_context.strokeRect(min_x, 
min_y, width, height);\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_figure_label = function (fig, msg) {\n", + " // Updates the figure title.\n", + " fig.header.textContent = msg['label'];\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_cursor = function (fig, msg) {\n", + " var cursor = msg['cursor'];\n", + " switch (cursor) {\n", + " case 0:\n", + " cursor = 'pointer';\n", + " break;\n", + " case 1:\n", + " cursor = 'default';\n", + " break;\n", + " case 2:\n", + " cursor = 'crosshair';\n", + " break;\n", + " case 3:\n", + " cursor = 'move';\n", + " break;\n", + " }\n", + " fig.rubberband_canvas.style.cursor = cursor;\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_message = function (fig, msg) {\n", + " fig.message.textContent = msg['message'];\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_draw = function (fig, _msg) {\n", + " // Request the server to send over a new figure.\n", + " fig.send_draw_message();\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_image_mode = function (fig, msg) {\n", + " fig.image_mode = msg['mode'];\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_history_buttons = function (fig, msg) {\n", + " for (var key in msg) {\n", + " if (!(key in fig.buttons)) {\n", + " continue;\n", + " }\n", + " fig.buttons[key].disabled = !msg[key];\n", + " fig.buttons[key].setAttribute('aria-disabled', !msg[key]);\n", + " }\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_navigate_mode = function (fig, msg) {\n", + " if (msg['mode'] === 'PAN') {\n", + " fig.buttons['Pan'].classList.add('active');\n", + " fig.buttons['Zoom'].classList.remove('active');\n", + " } else if (msg['mode'] === 'ZOOM') {\n", + " fig.buttons['Pan'].classList.remove('active');\n", + " fig.buttons['Zoom'].classList.add('active');\n", + " } else {\n", + " fig.buttons['Pan'].classList.remove('active');\n", + " fig.buttons['Zoom'].classList.remove('active');\n", + " }\n", + "};\n", + "\n", + "mpl.figure.prototype.updated_canvas_event = function () 
{\n", + " // Called whenever the canvas gets updated.\n", + " this.send_message('ack', {});\n", + "};\n", + "\n", + "// A function to construct a web socket function for onmessage handling.\n", + "// Called in the figure constructor.\n", + "mpl.figure.prototype._make_on_message_function = function (fig) {\n", + " return function socket_on_message(evt) {\n", + " if (evt.data instanceof Blob) {\n", + " /* FIXME: We get \"Resource interpreted as Image but\n", + " * transferred with MIME type text/plain:\" errors on\n", + " * Chrome. But how to set the MIME type? It doesn't seem\n", + " * to be part of the websocket stream */\n", + " evt.data.type = 'image/png';\n", + "\n", + " /* Free the memory for the previous frames */\n", + " if (fig.imageObj.src) {\n", + " (window.URL || window.webkitURL).revokeObjectURL(\n", + " fig.imageObj.src\n", + " );\n", + " }\n", + "\n", + " fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n", + " evt.data\n", + " );\n", + " fig.updated_canvas_event();\n", + " fig.waiting = false;\n", + " return;\n", + " } else if (\n", + " typeof evt.data === 'string' &&\n", + " evt.data.slice(0, 21) === 'data:image/png;base64'\n", + " ) {\n", + " fig.imageObj.src = evt.data;\n", + " fig.updated_canvas_event();\n", + " fig.waiting = false;\n", + " return;\n", + " }\n", + "\n", + " var msg = JSON.parse(evt.data);\n", + " var msg_type = msg['type'];\n", + "\n", + " // Call the \"handle_{type}\" callback, which takes\n", + " // the figure and JSON message as its only arguments.\n", + " try {\n", + " var callback = fig['handle_' + msg_type];\n", + " } catch (e) {\n", + " console.log(\n", + " \"No handler for the '\" + msg_type + \"' message type: \",\n", + " msg\n", + " );\n", + " return;\n", + " }\n", + "\n", + " if (callback) {\n", + " try {\n", + " // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n", + " callback(fig, msg);\n", + " } catch (e) {\n", + " console.log(\n", + " \"Exception inside the 'handler_\" + msg_type 
+ \"' callback:\",\n", + " e,\n", + " e.stack,\n", + " msg\n", + " );\n", + " }\n", + " }\n", + " };\n", + "};\n", + "\n", + "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n", + "mpl.findpos = function (e) {\n", + " //this section is from http://www.quirksmode.org/js/events_properties.html\n", + " var targ;\n", + " if (!e) {\n", + " e = window.event;\n", + " }\n", + " if (e.target) {\n", + " targ = e.target;\n", + " } else if (e.srcElement) {\n", + " targ = e.srcElement;\n", + " }\n", + " if (targ.nodeType === 3) {\n", + " // defeat Safari bug\n", + " targ = targ.parentNode;\n", + " }\n", + "\n", + " // pageX,Y are the mouse positions relative to the document\n", + " var boundingRect = targ.getBoundingClientRect();\n", + " var x = e.pageX - (boundingRect.left + document.body.scrollLeft);\n", + " var y = e.pageY - (boundingRect.top + document.body.scrollTop);\n", + "\n", + " return { x: x, y: y };\n", + "};\n", + "\n", + "/*\n", + " * return a copy of an object with only non-object keys\n", + " * we need this to avoid circular references\n", + " * http://stackoverflow.com/a/24161582/3208463\n", + " */\n", + "function simpleKeys(original) {\n", + " return Object.keys(original).reduce(function (obj, key) {\n", + " if (typeof original[key] !== 'object') {\n", + " obj[key] = original[key];\n", + " }\n", + " return obj;\n", + " }, {});\n", + "}\n", + "\n", + "mpl.figure.prototype.mouse_event = function (event, name) {\n", + " var canvas_pos = mpl.findpos(event);\n", + "\n", + " if (name === 'button_press') {\n", + " this.canvas.focus();\n", + " this.canvas_div.focus();\n", + " }\n", + "\n", + " var x = canvas_pos.x * this.ratio;\n", + " var y = canvas_pos.y * this.ratio;\n", + "\n", + " this.send_message(name, {\n", + " x: x,\n", + " y: y,\n", + " button: event.button,\n", + " step: event.step,\n", + " guiEvent: simpleKeys(event),\n", + " });\n", + "\n", + " /* This prevents the web browser from automatically changing to\n", + " * 
the text insertion cursor when the button is pressed. We want\n", + " * to control all of the cursor setting manually through the\n", + " * 'cursor' event from matplotlib */\n", + " event.preventDefault();\n", + " return false;\n", + "};\n", + "\n", + "mpl.figure.prototype._key_event_extra = function (_event, _name) {\n", + " // Handle any extra behaviour associated with a key event\n", + "};\n", + "\n", + "mpl.figure.prototype.key_event = function (event, name) {\n", + " // Prevent repeat events\n", + " if (name === 'key_press') {\n", + " if (event.which === this._key) {\n", + " return;\n", + " } else {\n", + " this._key = event.which;\n", + " }\n", + " }\n", + " if (name === 'key_release') {\n", + " this._key = null;\n", + " }\n", + "\n", + " var value = '';\n", + " if (event.ctrlKey && event.which !== 17) {\n", + " value += 'ctrl+';\n", + " }\n", + " if (event.altKey && event.which !== 18) {\n", + " value += 'alt+';\n", + " }\n", + " if (event.shiftKey && event.which !== 16) {\n", + " value += 'shift+';\n", + " }\n", + "\n", + " value += 'k';\n", + " value += event.which.toString();\n", + "\n", + " this._key_event_extra(event, name);\n", + "\n", + " this.send_message(name, { key: value, guiEvent: simpleKeys(event) });\n", + " return false;\n", + "};\n", + "\n", + "mpl.figure.prototype.toolbar_button_onclick = function (name) {\n", + " if (name === 'download') {\n", + " this.handle_save(this, null);\n", + " } else {\n", + " this.send_message('toolbar_button', { name: name });\n", + " }\n", + "};\n", + "\n", + "mpl.figure.prototype.toolbar_button_onmouseover = function (tooltip) {\n", + " this.message.textContent = tooltip;\n", + "};\n", + "\n", + "///////////////// REMAINING CONTENT GENERATED BY embed_js.py /////////////////\n", + "// prettier-ignore\n", + "var _JSXTOOLS_RESIZE_OBSERVER=function(A){var t,i=new WeakMap,n=new WeakMap,a=new WeakMap,r=new WeakMap,o=new Set;function s(e){if(!(this instanceof s))throw new TypeError(\"Constructor requires 'new' 
operator\");i.set(this,e)}function h(){throw new TypeError(\"Function is not a constructor\")}function c(e,t,i,n){e=0 in arguments?Number(arguments[0]):0,t=1 in arguments?Number(arguments[1]):0,i=2 in arguments?Number(arguments[2]):0,n=3 in arguments?Number(arguments[3]):0,this.right=(this.x=this.left=e)+(this.width=i),this.bottom=(this.y=this.top=t)+(this.height=n),Object.freeze(this)}function d(){t=requestAnimationFrame(d);var s=new WeakMap,p=new Set;o.forEach((function(t){r.get(t).forEach((function(i){var r=t instanceof window.SVGElement,o=a.get(t),d=r?0:parseFloat(o.paddingTop),f=r?0:parseFloat(o.paddingRight),l=r?0:parseFloat(o.paddingBottom),u=r?0:parseFloat(o.paddingLeft),g=r?0:parseFloat(o.borderTopWidth),m=r?0:parseFloat(o.borderRightWidth),w=r?0:parseFloat(o.borderBottomWidth),b=u+f,F=d+l,v=(r?0:parseFloat(o.borderLeftWidth))+m,W=g+w,y=r?0:t.offsetHeight-W-t.clientHeight,E=r?0:t.offsetWidth-v-t.clientWidth,R=b+v,z=F+W,M=r?t.width:parseFloat(o.width)-R-E,O=r?t.height:parseFloat(o.height)-z-y;if(n.has(t)){var k=n.get(t);if(k[0]===M&&k[1]===O)return}n.set(t,[M,O]);var S=Object.create(h.prototype);S.target=t,S.contentRect=new c(u,d,M,O),s.has(i)||(s.set(i,[]),p.add(i)),s.get(i).push(S)}))})),p.forEach((function(e){i.get(e).call(e,s.get(e),e)}))}return s.prototype.observe=function(i){if(i instanceof window.Element){r.has(i)||(r.set(i,new Set),o.add(i),a.set(i,window.getComputedStyle(i)));var n=r.get(i);n.has(this)||n.add(this),cancelAnimationFrame(t),t=requestAnimationFrame(d)}},s.prototype.unobserve=function(i){if(i instanceof window.Element&&r.has(i)){var n=r.get(i);n.has(this)&&(n.delete(this),n.size||(r.delete(i),o.delete(i))),n.size||r.delete(i),o.size||cancelAnimationFrame(t)}},A.DOMRectReadOnly=c,A.ResizeObserver=s,A.ResizeObserverEntry=h,A}; // eslint-disable-line\n", + "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], 
[\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Left button pans, Right button zooms\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n", + "\n", + "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n", + "\n", + "mpl.default_extension = \"png\";/* global mpl */\n", + "\n", + "var comm_websocket_adapter = function (comm) {\n", + " // Create a \"websocket\"-like object which calls the given IPython comm\n", + " // object with the appropriate methods. Currently this is a non binary\n", + " // socket, so there is still some room for performance tuning.\n", + " var ws = {};\n", + "\n", + " ws.close = function () {\n", + " comm.close();\n", + " };\n", + " ws.send = function (m) {\n", + " //console.log('sending', m);\n", + " comm.send(m);\n", + " };\n", + " // Register the callback with on_msg.\n", + " comm.on_msg(function (msg) {\n", + " //console.log('receiving', msg['content']['data'], msg);\n", + " // Pass the mpl event to the overridden (by mpl) onmessage function.\n", + " ws.onmessage(msg['content']['data']);\n", + " });\n", + " return ws;\n", + "};\n", + "\n", + "mpl.mpl_figure_comm = function (comm, msg) {\n", + " // This is the function which gets called when the mpl process\n", + " // starts-up an IPython Comm through the \"matplotlib\" channel.\n", + "\n", + " var id = msg.content.data.id;\n", + " // Get hold of the div created by the display call when the Comm\n", + " // socket was opened in Python.\n", + " var element = document.getElementById(id);\n", + " var ws_proxy = comm_websocket_adapter(comm);\n", + "\n", + " function ondownload(figure, _format) {\n", + " 
window.open(figure.canvas.toDataURL());\n", + " }\n", + "\n", + " var fig = new mpl.figure(id, ws_proxy, ondownload, element);\n", + "\n", + " // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n", + " // web socket which is closed, not our websocket->open comm proxy.\n", + " ws_proxy.onopen();\n", + "\n", + " fig.parent_element = element;\n", + " fig.cell_info = mpl.find_output_cell(\"
\");\n", + " if (!fig.cell_info) {\n", + " console.error('Failed to find cell for figure', id, fig);\n", + " return;\n", + " }\n", + " fig.cell_info[0].output_area.element.on(\n", + " 'cleared',\n", + " { fig: fig },\n", + " fig._remove_fig_handler\n", + " );\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_close = function (fig, msg) {\n", + " var width = fig.canvas.width / fig.ratio;\n", + " fig.cell_info[0].output_area.element.off(\n", + " 'cleared',\n", + " fig._remove_fig_handler\n", + " );\n", + " fig.resizeObserverInstance.unobserve(fig.canvas_div);\n", + "\n", + " // Update the output cell to use the data from the current canvas.\n", + " fig.push_to_output();\n", + " var dataURL = fig.canvas.toDataURL();\n", + " // Re-enable the keyboard manager in IPython - without this line, in FF,\n", + " // the notebook keyboard shortcuts fail.\n", + " IPython.keyboard_manager.enable();\n", + " fig.parent_element.innerHTML =\n", + " '';\n", + " fig.close_ws(fig, msg);\n", + "};\n", + "\n", + "mpl.figure.prototype.close_ws = function (fig, msg) {\n", + " fig.send_message('closing', msg);\n", + " // fig.ws.close()\n", + "};\n", + "\n", + "mpl.figure.prototype.push_to_output = function (_remove_interactive) {\n", + " // Turn the data on the canvas into data in the output cell.\n", + " var width = this.canvas.width / this.ratio;\n", + " var dataURL = this.canvas.toDataURL();\n", + " this.cell_info[1]['text/html'] =\n", + " '';\n", + "};\n", + "\n", + "mpl.figure.prototype.updated_canvas_event = function () {\n", + " // Tell IPython that the notebook contents must change.\n", + " IPython.notebook.set_dirty(true);\n", + " this.send_message('ack', {});\n", + " var fig = this;\n", + " // Wait a second, then push the new image to the DOM so\n", + " // that it is saved nicely (might be nice to debounce this).\n", + " setTimeout(function () {\n", + " fig.push_to_output();\n", + " }, 1000);\n", + "};\n", + "\n", + "mpl.figure.prototype._init_toolbar = function () {\n", + " var 
fig = this;\n", + "\n", + " var toolbar = document.createElement('div');\n", + " toolbar.classList = 'btn-toolbar';\n", + " this.root.appendChild(toolbar);\n", + "\n", + " function on_click_closure(name) {\n", + " return function (_event) {\n", + " return fig.toolbar_button_onclick(name);\n", + " };\n", + " }\n", + "\n", + " function on_mouseover_closure(tooltip) {\n", + " return function (event) {\n", + " if (!event.currentTarget.disabled) {\n", + " return fig.toolbar_button_onmouseover(tooltip);\n", + " }\n", + " };\n", + " }\n", + "\n", + " fig.buttons = {};\n", + " var buttonGroup = document.createElement('div');\n", + " buttonGroup.classList = 'btn-group';\n", + " var button;\n", + " for (var toolbar_ind in mpl.toolbar_items) {\n", + " var name = mpl.toolbar_items[toolbar_ind][0];\n", + " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", + " var image = mpl.toolbar_items[toolbar_ind][2];\n", + " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", + "\n", + " if (!name) {\n", + " /* Instead of a spacer, we start a new button group. 
*/\n", + " if (buttonGroup.hasChildNodes()) {\n", + " toolbar.appendChild(buttonGroup);\n", + " }\n", + " buttonGroup = document.createElement('div');\n", + " buttonGroup.classList = 'btn-group';\n", + " continue;\n", + " }\n", + "\n", + " button = fig.buttons[name] = document.createElement('button');\n", + " button.classList = 'btn btn-default';\n", + " button.href = '#';\n", + " button.title = name;\n", + " button.innerHTML = '';\n", + " button.addEventListener('click', on_click_closure(method_name));\n", + " button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n", + " buttonGroup.appendChild(button);\n", + " }\n", + "\n", + " if (buttonGroup.hasChildNodes()) {\n", + " toolbar.appendChild(buttonGroup);\n", + " }\n", + "\n", + " // Add the status bar.\n", + " var status_bar = document.createElement('span');\n", + " status_bar.classList = 'mpl-message pull-right';\n", + " toolbar.appendChild(status_bar);\n", + " this.message = status_bar;\n", + "\n", + " // Add the close button to the window.\n", + " var buttongrp = document.createElement('div');\n", + " buttongrp.classList = 'btn-group inline pull-right';\n", + " button = document.createElement('button');\n", + " button.classList = 'btn btn-mini btn-primary';\n", + " button.href = '#';\n", + " button.title = 'Stop Interaction';\n", + " button.innerHTML = '';\n", + " button.addEventListener('click', function (_evt) {\n", + " fig.handle_close(fig, {});\n", + " });\n", + " button.addEventListener(\n", + " 'mouseover',\n", + " on_mouseover_closure('Stop Interaction')\n", + " );\n", + " buttongrp.appendChild(button);\n", + " var titlebar = this.root.querySelector('.ui-dialog-titlebar');\n", + " titlebar.insertBefore(buttongrp, titlebar.firstChild);\n", + "};\n", + "\n", + "mpl.figure.prototype._remove_fig_handler = function (event) {\n", + " var fig = event.data.fig;\n", + " if (event.target !== this) {\n", + " // Ignore bubbled events from children.\n", + " return;\n", + " }\n", + " fig.close_ws(fig, 
{});\n", + "};\n", + "\n", + "mpl.figure.prototype._root_extra_style = function (el) {\n", + " el.style.boxSizing = 'content-box'; // override notebook setting of border-box.\n", + "};\n", + "\n", + "mpl.figure.prototype._canvas_extra_style = function (el) {\n", + " // this is important to make the div 'focusable\n", + " el.setAttribute('tabindex', 0);\n", + " // reach out to IPython and tell the keyboard manager to turn it's self\n", + " // off when our div gets focus\n", + "\n", + " // location in version 3\n", + " if (IPython.notebook.keyboard_manager) {\n", + " IPython.notebook.keyboard_manager.register_events(el);\n", + " } else {\n", + " // location in version 2\n", + " IPython.keyboard_manager.register_events(el);\n", + " }\n", + "};\n", + "\n", + "mpl.figure.prototype._key_event_extra = function (event, _name) {\n", + " var manager = IPython.notebook.keyboard_manager;\n", + " if (!manager) {\n", + " manager = IPython.keyboard_manager;\n", + " }\n", + "\n", + " // Check for shift+enter\n", + " if (event.shiftKey && event.which === 13) {\n", + " this.canvas_div.blur();\n", + " // select the cell after this one\n", + " var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n", + " IPython.notebook.select(index + 1);\n", + " }\n", + "};\n", + "\n", + "mpl.figure.prototype.handle_save = function (fig, _msg) {\n", + " fig.ondownload(fig, null);\n", + "};\n", + "\n", + "mpl.find_output_cell = function (html_output) {\n", + " // Return the cell and output element which can be found *uniquely* in the notebook.\n", + " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", + " // IPython event is triggered only after the cells have been serialised, which for\n", + " // our purposes (turning an active figure into a static one), is too late.\n", + " var cells = IPython.notebook.get_cells();\n", + " var ncells = cells.length;\n", + " for (var i = 0; i < ncells; i++) {\n", + " var cell = cells[i];\n", + " if (cell.cell_type 
=== 'code') {\n", + " for (var j = 0; j < cell.output_area.outputs.length; j++) {\n", + " var data = cell.output_area.outputs[j];\n", + " if (data.data) {\n", + " // IPython >= 3 moved mimebundle to data attribute of output\n", + " data = data.data;\n", + " }\n", + " if (data['text/html'] === html_output) {\n", + " return [cell, data, j];\n", + " }\n", + " }\n", + " }\n", + " }\n", + "};\n", + "\n", + "// Register the function which deals with the matplotlib target/channel.\n", + "// The kernel may be null if the page has been refreshed.\n", + "if (IPython.notebook.kernel !== null) {\n", + " IPython.notebook.kernel.comm_manager.register_target(\n", + " 'matplotlib',\n", + " mpl.mpl_figure_comm\n", + " );\n", + "}\n" + ], "text/plain": [ - "" + "" ] }, - "execution_count": 15, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" }, { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAELCAYAAADawD2zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3gVVeLG8e9JCCC9Ewi95IKKCgioKEWQIkXFhiIqoqCiP5qIoiIqNlxcG7jKiouuAq5KUzoCKyqCoFK99JIQSOgEFkg5vz9mCAkpBMIkwLyf55mHmTNnzpm5Q95Mzp0711hrERGRi19IXu+AiIjkDgW+iIhPKPBFRHxCgS8i4hMKfBERn1Dgi4j4hAL/7LUDgsAG4JkM1j8IxAF/uNPDp6wvBkQDH6QquwdYCawAZgJl3PK3gL/c8klAiXNxAD53uvMHcBewBlgNfJmqfCawH/julPpfuG2uAsYCYW55N5xztwL4Gbgy57svJwQCgXaBQCAYCAQ2BAKBdOcyEAj8PRAI/OFO6wKBwH63/KpAIPBLIBBYHQgEVgQCgbtTbfOE254NBAJlTm3zQqXAPzuhwCigPXApTlBfmkG9icBV7vTPU9a9AixMtZwPeBdoCVyBEw5PuOvmAJe75euAZ8/FQfhYds5fbZzXuSlwGdAv1bq3gO4ZtPsFUAeoB1zCyV/ym4HmOOfvFeDjc3EQAoFAIN25DAQCac5lMBjsHwwGrwoGg1cB7wPfuquOAPcHg8HLcC4A3gkEAicupn4CWgNbc+Ewco0C/+w0xrky3AQcByYAt5zB9g2B8sDsVGXGnQq7/xYDdrjrZgOJ7vxioNLZ7rgA2Tt/j+AEyT53OTbVunnAoQzanQ5Yd1rCyfP0c6p2dP7OrcbAhmAwuCkYDGbnZ/EeYDxAMBhcFwwG17vzO3DOcVl3+fdgMLjFyx3PCwr8sxMBbE+1HOWWnep2nCv1r4HKblkIMBIYdErdBOAxnCGdHThXK59k0OZDwIyz3XEBsnf+It3pJ5yQbncG7Yfh/AUwM4N1PdH5O5ey+7NIIBCoClQHfshgXWMgP7DRg308b5jc
frSCMeZja22vTNb1AnoBXJK/bMMCYcVydd+y65bb2nNjqxvo+8QQAO7ueisNrr6CwU+9nFKnZKkSHI4/wvHjx+nR8x5u7XIzt3ToziO9u3PJJQV5750x3NOtC/Ub1OPpgS+RL18+vp48ln5PPs+WzdsYMfJFdu2KY+SI0SltDhz0GFc1qEf3ex7P9WM+E38rcU1e70KWqnVoTKUWV7BokDPKVuv2ppS5qiaLX/gspc5N/xpIcmISPzz6PoUrlKLjty/wbatnOH7wCADh19alXu+bmfPgyHTtXz+iJwlHjvHrsH+nKa9wXV2ue/VBvrvtFY7tj/fwCHOm++LBeb0L2TZr4c/8vPRPXnrqMQCmzVnIyr82MOTJnunqfjJ+Ert27023Lm7PPh4a8CLDBz/BlZdGplnX9t7HmPDhm5Qsfn5mUUbyV6pnMluXz4sOjTGlMlsF3JzZdtbaj3HHN0sWqXXePuRnR/ROIipVSFmuGBHOzpjYNHX27d2fMj/u04kMe/lpABo1voprr2tEz0e6UbhIIcLC8nM4/ghTpzgXg1s2bwNg8rfT6Tegd0obXe+9jTbtbuTWjhkNHcuZOBKzl8IVTv4XLRReiiM796WpczhmL3HLN2ATk4jfHseBjTEUqx7O7j83Zdl2/f63UbBUURYNHpumvGTdylw/4mFmdX/rvA77C035MqXZGbc7ZXlX3B7KlS6ZYd2ZC37iuf9Le+9E/OEj9BnyGk881DVd2F+MvBrSiQN+A5almn5zp3Ie9Zlrli9bQc2aValStRJhYWF0uaMDM6bPS1OnfPmyKfPtO7QiGHT+UuzVcyD16jbjysta8MKQN5g4fhIvvfgWMTt2EahTi9JlnCBqcWPTlG1atW5G3wG9uffu3vzvf0dz6SgvXnF/bqJY9XCKVC5LSFgoNW65hm1zlqeps3XWMipc57z3V6BkEYrVCOfQ1tiMmksReU8LIprXY/4ToyDVX86FK5am9Zh+LOz7Dw5u3nnuD8jHLq9Ti63RMUTF7CIhIYEZ83+ixXWN0tXbvD2ag4cOc+WlgZSyhIQE+r04gk5tmtO2+XW5udt5xpMrfJw3w1pZa7edusIYsz2D+heUpKQknh74Et9M/pTQ0FC++Pw//LV2Pc8+35c/lq9ixvR59H7sAdp1aEVSYiL79h2gz6NPZ9nmzp2xjHj9fb6f9SWJCYls37aDx91tRox8kQIF8jNp6r8A+G3pHwzoO9Trw7xo2aRkfnlhHO2+eBoTEsK6iQvZvy6aBk/dzu4/N7NtznKiF6ygUrN6dPnhTWxyMkuHj0+5Mu/wzQsUr1WBsMIF6br0PX58agzRC1fS9PUexEftptOUYQBsmbGUP96ZTP3+t1GgRBGue+1BAJITk5jaQefvXMgXGsqQJx/m0cHDSUpO5rb2N1KrWmU++HQClwVq0tIN/xk/LKJdy6YYc3K0Y+aCX1i2Yi37D8YzZdYCAIY/3Yc6tarzxbffM3biFPbs3c/tjwzkhsYNUoaNLmSejOEbY/oAi6y1f2aw7klr7funa+N8HtKRrJ3vY/iStQtpDF/Sy/UxfGvtqCzWnTbsRUTk3PNqSAdjTB2c+2EjcO5L3gFMtdau9apPERHJnCdv2hpjBuN8AMLgfABlqTs/3hiT2cfYRUTEQ15d4fcELrPWJqQuNMa8jfNckjc86ldERDLh1W2ZyUDFDMoruOtERCSXeXWF3w+YZ4xZz8mPPVcBanHygWAiIpKLvLpLZ6YxJhLnwUYROOP3UcBSa22SF32KiEjWPLtLh5NPDbQ4wzgn/hURkTzg1bN02gCjgfU4X/IBziNhaxljHrfWzs50YxER8YRXV/jvAq2ttVtSFxpjquM8M7yuR/2KiEgmvLpLJx/OmP2pojn5tW8iIpKLvLrCHwssNcZM4ORdOpWBrmT8pR4iIuIxr+7Sed0YMwXoDFzLybt0ullr13jRp4iIZM2zu3TcYF/jfhmKtdbuO902IiLiHa+epVPF
GDPBGBML/AosMcbEumXVvOhTRESy5tWbthOBSUAFa21ta21tnMcqTMZ5qJqIiOQyrwK/jLV2YupP1Vprk6y1E4DSHvUpIiJZ8GoMf5kxZjQwjrR36TwA/O5RnyIikgWvAv9+nEckv8TJZ+lsB6ah2zJFRPKEV7dlHgc+dCcRETkPeDWGnyljTMfc7lNERPIg8IFGedCniIjv5cWXmL/oVZ8iIpI5fYm5iIhP6EvMRUR8Ql9iLiLiE/oScxERn9CXmIuI+ISXj0dOBhZ71b6IiJyZvLgPX0RE8oACX0TEJxT4IiI+ocAXEfEJBb6IiE8o8EVEfEKBLyLiEwp8ERGfUOCLiPiEAl9ExCcU+CIiPqHAFxHxCQW+iIhPKPBFRHxCgS8i4hMKfBERn1Dgi4j4hAJfRMQnFPgiIj6hwBcR8QnPvsQ8p44nJ+b1LshZSjB5vQeSE6Zg4bzeBfGIrvBFRHxCgS8i4hMKfBERn1Dgi4j4hAJfRMQnFPgiIj6hwBcR8QkFvoiITyjwRUR8QoEvIuITCnwREZ9Q4IuI+IQCX0TEJxT4IiI+ocAXEfEJBb6IiE8o8EVEfEKBLyLiEwp8ERGfUOCLiPiEAl9ExCcU+CIiPqHAFxHxCQW+iIhPKPBFRHxCgS8i4hMKfBERn1Dgi4j4hAJfRMQnFPgiIj6hwBcR8QkFvoiITyjwRUR8QoEvIuITCnwREZ9Q4IuI+IQCX0TEJxT4IiI+ocAXEfEJBb6IiE8o8EVEfEKBLyLiEwp8ERGfUOCLiPiEAl9ExCcU+CIiPqHAFxHxCQW+iIhPKPBFRHxCgS8i4hP58noHLlQ33dScEW8NJTQ0lHH/msjIkR+mWX/ffXcw/NVniYnZBcA//jGOcf+amLK+aNEiLP99LlOnzmLggBcBmDFzAuHhZTl69BgAnTt1Jy5uDwBdunRgyHP9sNayauVaevTomxuHedGq3OIKmg7rjgkNYe34Bfwxelq6OjU7NqFh/y5gLXvWbmPek6MBuPnzpylfvyY7l65jRo+RKfWLVi5L61F9KFiiCHGrtvBD3w9JTkiiQpMA173YndJ1KzO3zwdsmr40147zYrRo8W+88c4/SEpO5vZO7Xi4+11p1r/57kcsWb4CgKPHjrF3335+mfU1f63byCt/+4D4w0cICQ2h1/1dad+6eZptX3t7NJOmz2Hp3EkpZTPn/ZfRY/+NwRCoXYMRwwZ7f5AeOW3gG2NCgZLW2t3ucn7gQaC/tbaut7t3fgoJCeHtv79Mp473ER29kx9/nMr338/hr782pKn3zTffpYT5qYYOHciiH39NV/7QQ/34ffnKNGU1a1bjqUGP07rV7ezff5CyZUufu4PxIRNiuH74A3x37xscjtlLl+9eZuucZexbvyOlTvFq5anfpxOTu7zE8QNHKFi6WMq6P//xPfkuyc+l3W5M0+41z3ZlxT9nsnHqYm54rQd1urZgzefziI/ew/wBH3Fl75tz7RgvVklJSQwfOYox77xGeLky3P1wX1pe34Sa1aum1Bnct3fK/Bf/mcLa9RsBKFiwAK+98BRVK0cQG7eHu3o+SdMmDSlWtAgAq9au42D84TT9bd0ezT8/n8jnH46keLGi7Nm3PxeO0jtZDukYY7oCe4EVxpiFxpiWwCagPdAtF/bvvHT11VexaeNWtmzZTkJCAl9/PY2OHdtke/ur6l9O2XJlmDfvx2zV79GjKx999Bn79x8ESLnql7NT7qqaHNyyi0Pb4khOSGLj1MVUa9MwTZ2697Zk1bi5HD9wBICjew6mrIv+aTUJ8UfTtVux6aVs+n4JAOu+/pHqbZ02D0XtZu9f28Farw7JN1auXUeVShWpHFGBsLAw2rdqzg8/Ls60/vS5C7m5dQsAqlWpRNXKEQCUK1uaUiVLsG//AcD5RTJy1CcMfLxnmu2/njqTrl06UbxYUQBKlyzhwVHlntON4T8PNLTWVgT6AzOBJ621t1lrl59Nh8aYHmez3fmkYsXy
REWfvBqMjo6hQsXy6erdemt7fv11Bv/+YjQRERUAMMbw+uvP89yQ1zJs+6N/vMUvi6cz+JknU8pq1a5B7VrVmTvva+YvmMRNNzXPcFvJnsLhJYnfsTdlOT5mL4XDS6apU7xGOCVqhHPrt0O5bcowKre4Iss2C5YswvGDR7BJyZm2KTkXG7eb8HJlU5bLlytDbCYXQDt27iI6ZidNGl6Zbt3KNUESEhKp7P5cfvnNNFpefw1ly5RKU2/r9mi2bo/mvkcHcu8j/Vi0+LdzeDS573SBf9xauwHADfjN1tpJp9nmdF7KbIUxppcx5jdjzG+JiYdy2I13jDHpyuwpV2/Tp8+lbp3radKkPfPn/8SYMc5Yb6/e3Zk9az7R0THp2njoob40btyOm1rfSdPrGnHvvV0AyJcvlJq1qtOubVcefOBJRo1+g+LFi6XbXrIpw/OXdjkkNJTi1cOZeterzH1iFM1HPEz+YoVy1KbkXEavaQYvPQAz5i6kTYvrCQ0NTVMet3svz778FsOH9CckJITYuD3Mnv8j997ROV0biUlJbI2K5tMP3mTES8/w4hvvcPBQ/Lk4lDxxujH8csaYAamWi6Retta+ndFGxpgVmbRngPSXwifb+xj4GKBwoWrn7Y9LdPROKkVUTFmOiKjAzpjYNHX27j051vfp2PG88orzRk+Txg24rmkjHunVncKFC5E/fxiH448wdOibxOxw3uCNjz/MV19NpeHVV/Lll98SHb2TpUt+JzExka1bo1i/bhM1a1Vj+bLMXmbJyuGYvRSpePJKrkiFUhzZtS9NnfiYvcT+voHkxCQObY9j/8YYilcPJ+7PTRm2eXTvIfIXK4QJDcEmJWfYpuRc+XJl2Bkbl7K8K3Y3Zctk/J7WjLkLeW5gnzRl8YcP8/igoTzZ6wGuvNx5C3Lt+o1si4rh5rsfAuDo0WO0v+shZnw1lvJly3DlZXUIy5ePShXDqValElujoqlXN+DREXrrdFf4Y4CiqabUy0Wy2K48cD/QKYPpgh+AXrbsT2rWqkbVqpUICwvjjjs68f33c9LUCQ8/+Wdnh443EQw6bxw99FA/6gSacmnd63luyGt8+eW3DB36JqGhoZQu7QwB5MuXj3btb2TNmnUAfDdtNs2aXQtA6dIlqVW7Ols2b8uNQ70oxf65ieLVwilauSwhYaHU7HwNW+akHaHcMnsZFa+9FHCGa0rUCOfg1tiMmkux4+c11OjQGIDIO25gy+yzGvWULFxeJ5JtUTuI2rGThIQEZsxbSMvrr0lXb/PWKA4eiueqy0/eV5KQkEDfZ1+hc7tWtL3xhpTy5tc1ZuG0L5n9zThmfzOOggULMOOrsQC0anYtS5b/CcC+/QfYsj2ayhUreHyU3snyCt9am9XwS78sNv0OKGKt/SOD7RZke+/OU0lJSQwcMJQpUz8jNDSUzz77irVr1/P8C/1Zvnwl07+fy2OP9eDmDq1JSkxi77799O71VJZtFiiQnylTPyMsXz5CQkNZMP8nPh07HoA5cxbSqtUN/LZsDslJSTw35PU0f0HImbFJySx6YRwd/v00JjSE4MSF7FsXzdUDbyduxWa2zlnO9gUrqNSsHnfNexObnMwvr47n2H7nT/lbvnmBEjUrEFa4IPcteY8Fg8YQtXAli1+fwE2jnqDxoDvZvWoLaycsAKDslTVoO6YfBYoXomrr+lw94Ha+av1MHr4CF658+UIZ0v8xeg94nqSkJG7r2IZaNarywZjPuKxOJC1vcMJ/+twFtG/dPM3w68wffmTZH6vYf+AQk6fPBeDV5wZQJ7Jmpv01bdKQn5csp3O3XoSGhDKwT09KXMDDqebUsedsb2jMNmttlXO8PynO5yEdydrI0tfn9S5IDvT8/eW83gXJgbAyNTJ5VyNnH7zKtFEA4/xqbQxEABbYASyxZ/sbRkREciQngZ9pcBtj2gCjgfVAtFtcCahljHncWjs7B/2KiMhZyDLwjTGHyDjYDXBJFpu+C7S2
1m45pb3qwHTAl5/QFRHJS6d707ZoDtqNyqA8Ggg7yzZFRCQHvHp42lhgqTFmArDdLasMdAU+8ahPERHJgieBb6193RgzBegMXIszBBQFdLPWrvGiTxERyZpnj0d2g32NMaaUs2j1sUMRkTzkyRegGGOqGGMmGGNigV+BJcaYWLesmhd9iohI1rz6xquJwCSggrW2trW2NlABmAxM8KhPERHJgleBX8ZaO9Fam3SiwFqbZK2dAOjbO0RE8oBXY/jLjDGjgXGkvUvnAeB3j/oUEZEseBX49wM9cZ59H4Fzl852YBq6LVNEJE94dVvmceBDdxIRkfOAV2P4mTLGdMztPkVEJA8CH2iUB32KiPieZx+8MsbUAW4h7eORp1prX/SqTxERyZxXH7wajHO/vQGWAEvd+fHGGH3Vj4hIHvDqCr8ncJm1NiF1oTHmbWA18IZH/YqISCa8GsNPBipmUF7BXSciIrnMqyv8fsA8Y8x6Tn7wqgpQC3jCoz5FRCQLXt2HP9MYE8nJ77Q98XjkpakftyAiIrnHy8cjJwOLvWpfRETOTF7chy8iInlAgS8i4hMKfBERn1Dgi4j4hAJfRMQnFPgiIj6hwBcR8QkFvoiITyjwRUR8QoEvIuITCnwREZ9Q4IuI+IQCX0TEJxT4IiI+ocAXEfEJBb6IiE8o8EVEfEKBLyLiEwp8ERGfUOCLiPiEAl9ExCcU+CIiPqHAFxHxCQW+iIhPKPBFRHxCgS8i4hMKfBERn8iX1zuQmWOJCXm9C3KWiiTl9R5IjiTrBF6sdIUvIuITCnwREZ9Q4IuI+IQCX0TEJxT4IiI+ocAXEfEJBb6IiE8o8EVEfEKBLyLiEwp8ERGfUOCLiPiEAl9ExCcU+CIiPqHAFxHxCQW+iIhPKPBFRHxCgS8i4hMKfBERn1Dgi4j4hAJfRMQnFPgiIj6hwBcR8QkFvoiITyjwRUR8QoEvIuITCnwREZ9Q4IuI+IQCX0TEJxT4IiI+ocAXEfEJBb6IiE8o8EVEfEKBLyLiEwp8ERGfUOCLiPiEAl9ExCcU+CIiPqHAFxHxCQW+iIhPKPBFRHxCgS8i4hMKfBERn1Dgi4j4hAJfRMQnFPgiIj6hwBcR8QkFvoiITyjwRUR8QoEvIuITCnwREZ/Il9c7cKFq26YFb7/9MqEhIYz9dDwj3hqVZv393e/izTeeJ3rHTgBGj/6UsZ+OT1lftGgRVq1YwOQpM+nb73kAGtSvxyef/J1LChZkxswf6D9gaEr9Po/34PHHe5CYmMiMGfN45tlXc+EoL14VWlxBo1e6Y0JC2DB+Aas/mJauTpVOTbhiYBewln1rtvFTn9EA3PjF05RpUJPYJetY8MDIdNtdPfx+at7djIm1Hwag4bBulG96KQD5CuanYJlifFW3t4dHd3Fb9Osy3nj3Y5KSk7m9Yxsevu/ONOvffG8MS35fAcDRo8fYu/8Av8yYyF/rN/HKyFHEH/4fISEh9Lr/Ltq3agbAC2+8y+q/1mMtVKtckVeH9KdQoUuI2RXLkFf/zqH4wyQlJdP/0Qdodm2jXD/mc0WBfxZCQkJ4791XaXfzPURFxbD4l+lM+242a9euT1Pvq/9MTQnzU700bBD//XFxmrJRH7zOY48NZvGvy/hu6ue0a9uSmbPm06L5dXTu1Jb6DVpz/PhxypYt7dmx+YEJMTR+7QHmdX2DIzF7aT/9ZaJmLePA+h0pdYpWL8/lT3Zi9i0vcfzAEQqULpaybs2H3xN6SX5q33djurZLXVGd/MUKpSlbNuyLlPnAQzdR8vJq5/6gfCIpKYnhb3/ImL8PJ7xsae5+pD8tmzahZvUqKXUG/98jKfNffD2Ntes3AlCwQAFee24AVStHELt7D3f17EfTxg0oVrQIg598hCKFnfM24v0xfPntdzx83518NG4ibVveQNfbbmbj5m089vQwZv/nwg18DemchcaN6rNx4xY2b95GQkICX301hc6d2mZ7+wb161G+fFnmzPlv
Sll4eDmKFivK4l+XAfD5F1/TuXM7AHr3vp8Rb43i+PHjAMTF7TmHR+M/pevX5NCWXcRviyM5IYktUxZTqW3DNHVqdWvJun/N5fiBIwAc23MwZd3ORatJjD+arl0TYmjwwj38PnxCpn1Xu/Vatkz+5Rwdif+sXLuOKhEVqFwxnLCwMNq3asYPixZnWn/6vIXc3Lo5ANWqRFC1cgQA5cqUplTJ4uzbfwAgJeyttRw9dhxjDADGGA4fcf4PHDp8mLJlSnl2bLkh1wPfGFMnt/s81ypGhLM96uTVYFR0DBUrhqer1+W2m1m+bA4TJ3xMpUoVAec/0FsjhjL4meFp6kZUDCc6KiZlOToqhgi3zdq1a3D99Y35edE0fpj7NVc3vNKLw/KNQuElObJjb8rykZi9FKpQMk2dYjXCKVojnDZThtJ22jAqtLjitO1G9mhD1Ozl/C92f4brC0eUpkjlcuxatDpnB+BjsXF7CC9XNmW5fNkyxO7O+AJox85YonfsokmD9Odu5ZogCYmJVI6okFL2/Gvv0PyW7mzeFsW9t3cE4PEe9/Ld7Pm06vIAjw8axpB+j57jI8pl1tpcnYBtWazrBfzmTr1ye9/OYLrTWvvPVMvdrbXvn1KntLW2gDv/qLX2B3f+CWvt0+7xPWit/cAtb2StnZtq+xustdPc+VXW2vestcZa29hau9mdz+vX4UKdsnP+vrPWTrLWhllrq1tro6y1JU6sf+qpp/7m1jlRv6K1dpG1Np+7HJ9Bv4Mz6EfTGUyRkZF3RkZG/jPVcvfIyMgMX9PIyMjBGa0rWLDgoMjIyGBkZOQ1GWwTGhkZOToyMrKHuzwgMjJyoDt/bWRk5JrIyMiQvH4dznbyZAzfGPNeZquAEpltZ639GPjYi306x6KAyqmWKwE7TqmT+rJjDPCmO38tcENUVFR54BCQH4gH3nXbyajNKOBbwAJLgGSgDBCX0wPxqeycvyhgMZAAbAaCQG1gKcDatWs7AanftKkP1AI2uMuF3Plaqep0BfqckyPwr+ycuxPSvd6BQKBY+fLlhwEPBoPBdGNBwWAwKRAITAQGAZ8CPYF27rpfAoFAQZyfvdgcHkee8GpIpwewClh2yvQbcNyjPnPTUpwf/uo4gd0VmHpKnQqp5jsDa935bkCVSpUqrQSeAj4DngFicH4BXIPzi/F+YIq7zWTgxDuEkW6fu8/d4fhOds7fZKClO18G53XflEWb3wPhQDV3OkLasA8AJQEN4OfMUqB2IBCoHggEMjt3BAKBdK+3W3/SoUOH9gSDwf+kKjeBQKDWiXmgE/CXu3ob0MpdVxcoyAV8oeXVXTpLgVXW2p9PXWGMGeZRn7kpEXgCmAWEAmOB1cDLOL/UpgL/hxP0icBe4MFstPsY8C/gEmCGO+G2Pxbnl+hx4AGcq305O9k5f7OANsAaIAnniu/EX20/fv755zWAKjhXnD3d+lm5B5iAzluOBIPBxEAgkObcBYPB1YFA4GXgt2AweCL87wEmBIPB1K/3XUCzYsWKJQQCgT/csgeBFcC4QCBQDOdi60+cn0WAgcCYQCDQH+fcPXhKmxcUY+2533djTCngqLX2yDlv/CJhjOnlDmHJBUjn78Ll53PnSeCn6cAJf2ut3edpRyIikiVPxvCNMVWMMROMMXHAr8BSY0ysW1bNiz5FRCRrXr1pOxGYBIRba2tba2vhvIk5GWccU0REcplXgV/GWjvRWpt0osBam2StnQBcdM8FMMa0M8YEjTEbjDHPZLC+gDFmorv+19R/5RhjnnXLg8aYtqnKx7p/Fa3KnaMQOPtzaYwpbYyZb4yJN8Z8kNv7Lell41w2M8YsN8YkGmPuyIt9zG1eBf4yY8xoY0wTY0xFd2pijBkN/O5Rn3nCGBMKjALaA5cC9xhjLj2lWk9gn/uXzt9x78l363UFLsO513e02x44d+u08/wAJEVOziVwFHgB51ZbyWPZPJfbcO7S+TJ39y7veBX49wMrgZdwbp+a
DQzDua2wu0d95pXGwFCul9MAAAO1SURBVAZr7SZr7XGcIatbTqlzCzDOnf8aaGWch3XcAkyw1h6z1m7G+aBOYwBr7X9xbueU3HPW59Jae9hauwgn+CXvnfZcWmu3WGtX4HyQ0Rc8uQ/ffYE/dKeLXQSwPdVyFNAkszrW2kRjzAGcoa0InE9zpt42wrtdldPIybnUB+HOL9k5l76TFw9P65jbfXrMZFB26r2umdXJzraSe3JyLuX8ovOUgbx4PPKF+zDpjGX3uSyVAYwx+YDiOMM1Z/JcEPFeTs6lnF/0s5UBzwLfGFPHGDPYGPOeMeZdd76utfZFr/rMI0uB2saY6saYzJ7tMRXncQgAdwA/WOcTb1OBru6dH9Vxnu+yJJf2W9LLybmU80t2zqXvePXBq8E4b5IYnABb6s6Pz+j2qAuZtTb1c1nWAl9Za1cbY142xnR2q30ClDbGbAAG4DwsDWvtauArnOe1zAT6nLiV1RgzHufBTwFjTJQxpmduHpcf5eRcAhhjtgBvAw+65+zUu0Ikl2TnXBpjGhljooA7gY+MMRf9FxV49SyddcBl1tqEU8rzA6uttbXPeaciIpIlr4Z0koGKGZRXwEe3QImInE+8ejxyP2CeMWY9J2+NqoLzfPAnPOpTRESy4NnTMo0xITgffojAGb+PApamftyCiIjkHs8fjywiIueHvLgPX0RE8oACX3zDGBPvQZtbjDFl8qJvkTOlwBcR8Qmv7tIRuSAYYzoBzwP5cb6kvJu1dpcxZhhQHedW4kicD1ldg/O43WigU6rPmQwyxrR05++11m5wPzn9Jc7P2MxU/RUBpgAlgTDgeWvtFG+PUsShK3zxu0XANdba+jifDn861bqaQAecx+r+G5hvra0H/M8tP+GgtbYx8AHwjlv2LvChtbYRsDNV3aPAbdbaBkBLYKT7qGwRzynwxe8qAbOMMSuBQThfRnPCDPcqfiUQyskr9ZVAtVT1xqf691p3vmmq8s9T1TXAa8aYFcBcnNuWy5+TIxE5DQW++N37wAfulXtvoGCqdccArLXJQEKqh6Qlk3Y41GZj/oRuQFmgobX2KmDXKX2KeEaBL35XHGdMHk4+BfNM3Z3q31/c+Z9wntAITsin7i/WWpvgjvtXPcs+Rc6Y3rQVPynkPh3xhLdxvnrzP8aYaJxvH6t+Fu0WMMb8inMBdY9b1hf40hjTF/gmVd0vgGnGmN+AP4C/zqI/kbOiT9qKiPiEhnRERHxCgS8i4hMKfBERn1Dgi4j4hAJfRMQnFPgiIj6hwBcR8Yn/B5WB+yFi76KKAAAAAElFTkSuQmCC\n", + "text/html": [ + "" + ], "text/plain": [ - "
" + "" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -838,7 +1802,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -859,22 +1823,22 @@ "\n", "1. Rendle, Steffen. \"Factorization machines.\" 2010 IEEE International Conference on Data Mining. IEEE, 2010.\n", "2. Juan, Yuchin, et al. \"Field-aware factorization machines for CTR prediction.\" Proceedings of the 10th ACM Conference on Recommender Systems. ACM, 2016.\n", - "3. Guo, Huifeng, et al. \"DeepFM: a factorization-machine based neural network for CTR prediction.\" arXiv preprint arXiv:1703.04247 (2017).\n", + "3. Guo, Huifeng, et al. \"DeepFM: a factorization-machine based neural network for CTR prediction.\" arXiv preprint arXiv:1703.04247, 2017.\n", "4. Lian, Jianxun, et al. \"xdeepfm: Combining explicit and implicit feature interactions for recommender systems.\" Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018.\n", "5. Qu, Yanru, et al. \"Product-based neural networks for user response prediction.\" 2016 IEEE 16th International Conference on Data Mining (ICDM). IEEE, 2016.\n", "6. Zhang, Weinan, Tianming Du, and Jun Wang. \"Deep learning over multi-field categorical data.\" European conference on information retrieval. Springer, Cham, 2016.\n", "7. He, Xiangnan, and Tat-Seng Chua. \"Neural factorization machines for sparse predictive analytics.\" Proceedings of the 40th International ACM SIGIR conference on Research and Development in Information Retrieval. ACM, 2017.\n", "8. Cheng, Heng-Tze, et al. \"Wide & deep learning for recommender systems.\" Proceedings of the 1st workshop on deep learning for recommender systems. ACM, 2016.\n", - "9. 
Langford, John, Lihong Li, and Alex Strehl. \"Vowpal wabbit online learning project.\" (2007)." + "9. Langford, John, Lihong Li, and Alex Strehl. \"Vowpal wabbit online learning project.\", 2007." ] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { - "display_name": "Python (reco_base)", + "display_name": "recommenders", "language": "python", - "name": "reco_base" + "name": "conda-env-recommenders-py" }, "language_info": { "codemirror_mode": { @@ -886,7 +1850,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.10" + "version": "3.6.13" } }, "nbformat": 4, diff --git a/reco_utils/common/tf_utils.py b/reco_utils/common/tf_utils.py index dcb80aaf5d..1d4f9567f7 100644 --- a/reco_utils/common/tf_utils.py +++ b/reco_utils/common/tf_utils.py @@ -206,10 +206,9 @@ def evaluation_log_hook( """Evaluation log hook for TensorFlow high-level API Estimator. .. note:: - - TensorFlow Estimator model uses the last checkpoint weights for evaluation or prediction. - In order to get the most up-to-date evaluation results while training, - set model's `save_checkpoints_steps` to be equal or greater than hook's `every_n_iter`. + TensorFlow Estimator model uses the last checkpoint weights for evaluation or prediction. + In order to get the most up-to-date evaluation results while training, + set model's `save_checkpoints_steps` to be equal or greater than hook's `every_n_iter`. Args: estimator (tf.estimator.Estimator): Model to evaluate. @@ -223,8 +222,8 @@ def evaluation_log_hook( batch_size (int): Number of samples fed into the model at a time. Note, the batch size doesn't affect on evaluation results. eval_fns (iterable of functions): List of evaluation functions that have signature of - (true_df, prediction_df, **eval_kwargs)->(float). If None, loss is calculated on true_df. - **eval_kwargs: Evaluation function's keyword arguments. + (true_df, prediction_df, \*\*eval_kwargs)->(float). If None, loss is calculated on true_df. 
+ eval_kwargs: Evaluation function's keyword arguments. Note, prediction column name should be 'prediction' Returns: diff --git a/reco_utils/dataset/amazon_reviews.py b/reco_utils/dataset/amazon_reviews.py index 26c908d65f..a332bf19dc 100644 --- a/reco_utils/dataset/amazon_reviews.py +++ b/reco_utils/dataset/amazon_reviews.py @@ -482,7 +482,15 @@ def _data_processing(input_file): def download_and_extract(name, dest_path): - """Downloads and extracts Amazon reviews and meta datafiles if they don’t already exist""" + """Downloads and extracts Amazon reviews and meta datafiles if they don’t already exist + + Args: + name (str): Category of reviews + dest_path (str): File path for the downloaded file + + Returns: + str: File path for the extracted file + """ dirs, _ = os.path.split(dest_path) if not os.path.exists(dirs): os.makedirs(dirs) @@ -499,6 +507,7 @@ def _download_reviews(name, dest_path): """Downloads Amazon reviews datafile. Args: + name (str): Category of reviews dest_path (str): File path for the downloaded file """ diff --git a/reco_utils/dataset/criteo.py b/reco_utils/dataset/criteo.py index f31fa8e5dc..1cebefb6b4 100644 --- a/reco_utils/dataset/criteo.py +++ b/reco_utils/dataset/criteo.py @@ -27,7 +27,7 @@ def load_pandas_df(size="sample", local_cache_path=None, header=DEFAULT_HEADER): - """Loads the Criteo DAC dataset as `pandas.DataFrame. This function download, untar, and load the dataset. + """Loads the Criteo DAC dataset as `pandas.DataFrame`. This function download, untar, and load the dataset. The dataset consists of a portion of Criteo’s traffic over a period of 24 days. Each row corresponds to a display ad served by Criteo and the first @@ -164,6 +164,14 @@ def extract_criteo(size, compressed_file, path=None): def get_spark_schema(header=DEFAULT_HEADER): + """Get Spark schema from header. + + Args: + header (list): Dataset header names. + + Returns: + pyspark.sql.types.StructType: Spark schema. 
+ """ ## create schema schema = StructType() ## do label + ints diff --git a/reco_utils/dataset/download_utils.py b/reco_utils/dataset/download_utils.py index 0d0d19d6ad..8f88566c81 100644 --- a/reco_utils/dataset/download_utils.py +++ b/reco_utils/dataset/download_utils.py @@ -9,10 +9,31 @@ from contextlib import contextmanager from tempfile import TemporaryDirectory from tqdm import tqdm +import backoff +import logging + log = logging.getLogger(__name__) +def _retry_logger(details): + log.info( + "Backing off {wait:0.1f} seconds after {tries} tries " + "calling function {target} with args {args} and kwargs " + "{kwargs}".format(**details) + ) + + +@backoff.on_exception( + backoff.expo, + ( + requests.exceptions.HTTPError, + requests.exceptions.ChunkedEncodingError, + requests.exceptions.ConnectionError, + ), + max_tries=5, + on_backoff=_retry_logger, +) def maybe_download(url, filename=None, work_directory=".", expected_bytes=None): """Download a file if it is not already downloaded. @@ -21,7 +42,7 @@ def maybe_download(url, filename=None, work_directory=".", expected_bytes=None): work_directory (str): Working directory. url (str): URL of the file to download. expected_bytes (int): Expected file size in bytes. - + Returns: str: File path of the file downloaded. """ @@ -57,8 +78,8 @@ def maybe_download(url, filename=None, work_directory=".", expected_bytes=None): @contextmanager def download_path(path=None): - """Return a path to download data. If `path=None`, then it yields a temporal path that is eventually deleted, - otherwise the real path of the input. + """Return a path to download data. If `path=None`, then it yields a temporal path that is eventually deleted, + otherwise the real path of the input. Args: path (str): Path to download data. 
@@ -82,7 +103,7 @@ def download_path(path=None): yield path -def unzip_file(zip_src, dst_dir, clean_zip_file=True): +def unzip_file(zip_src, dst_dir, clean_zip_file=False): """Unzip a file Args: diff --git a/reco_utils/dataset/mind.py b/reco_utils/dataset/mind.py index 3bbe681316..14a7aa8369 100644 --- a/reco_utils/dataset/mind.py +++ b/reco_utils/dataset/mind.py @@ -231,7 +231,15 @@ def get_words_and_entities(train_news, valid_news): return news_words, news_entities -def download_and_extract_globe(dest_path): +def download_and_extract_glove(dest_path): + """Download and extract the Glove embedding + + Args: + dest_path (str): Destination directory path for the downloaded file + + Returns: + str: File path where Glove was extracted. + """ url = "http://nlp.stanford.edu/data/glove.6B.zip" filepath = maybe_download(url=url, work_directory=dest_path) glove_path = os.path.join(dest_path, "glove") @@ -269,7 +277,7 @@ def generate_embeddings( ) logger.info("Downloading glove...") - glove_path = download_and_extract_globe(data_path) + glove_path = download_and_extract_glove(data_path) word_set = set() word_embedding_dict = {} diff --git a/reco_utils/dataset/sparse.py b/reco_utils/dataset/sparse.py index 901f4c9ccc..5f5e43e2ef 100644 --- a/reco_utils/dataset/sparse.py +++ b/reco_utils/dataset/sparse.py @@ -110,12 +110,12 @@ def gen_affinity_matrix(self): As a first step, two new columns are added to the input DF, containing the index maps generated by the gen_index() method. The new indices, together with the ratings, are then used to generate the user/item affinity matrix using scipy's sparse matrix method - coo_matrix; for reference see: - https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html + coo_matrix; for reference see: + https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html. 
The input format is: `coo_matrix((data, (rows, columns)), shape=(rows, columns))` + Returns: - scipy.sparse.coo_matrix: user-affinity matrix of dimensions (Nusers, Nitems) in numpy format. Unrated movies - are assigned a value of 0. + scipy.sparse.coo_matrix: user-affinity matrix of dimensions (Nusers, Nitems) in numpy format. Unrated movies are assigned a value of 0. """ log.info("Generating the user/item affinity matrix...") diff --git a/reco_utils/evaluation/python_evaluation.py b/reco_utils/evaluation/python_evaluation.py index 35219cbd01..0414e0f936 100644 --- a/reco_utils/evaluation/python_evaluation.py +++ b/reco_utils/evaluation/python_evaluation.py @@ -32,11 +32,15 @@ def check_column_dtypes(func): """Checks columns of DataFrame inputs This includes the checks on: - 1. whether the input columns exist in the input DataFrames - 2. whether the data types of col_user as well as col_item are matched in the two input DataFrames. + + * whether the input columns exist in the input DataFrames + * whether the data types of col_user as well as col_item are matched in the two input DataFrames. Args: func (function): function that will be wrapped + + Returns: + function: Wrapper function for checking dtypes. """ @wraps(func) diff --git a/reco_utils/evaluation/spark_evaluation.py b/reco_utils/evaluation/spark_evaluation.py index 93542ab3ca..e8dde7b371 100644 --- a/reco_utils/evaluation/spark_evaluation.py +++ b/reco_utils/evaluation/spark_evaluation.py @@ -132,17 +132,17 @@ def mae(self): return self.metrics.meanAbsoluteError def rsquared(self): - """Calculate R squared + """Calculate R squared. + Returns: - float: R squared + float: R squared. """ return self.metrics.r2 def exp_var(self): """Calculate explained variance. - NOTE: - Spark MLLib's implementation is buggy (can lead to values > 1), hence we use var(). + :note: Spark MLLib's implementation is buggy (can lead to values > 1), hence we use var(). Returns: float: Explained variance (min=0, max=1). 
diff --git a/reco_utils/recommender/newsrec/models/base_model.py b/reco_utils/recommender/newsrec/models/base_model.py index a175f0e5dd..2e3305bc1c 100644 --- a/reco_utils/recommender/newsrec/models/base_model.py +++ b/reco_utils/recommender/newsrec/models/base_model.py @@ -19,8 +19,8 @@ class BaseModel: Attributes: hparams (obj): A tf.contrib.training.HParams object, hold the entire set of hyperparameters. - iterator_creator_train (obj): An iterator to load the data in training steps. - iterator_creator_train (obj): An iterator to load the data in testing steps. + train_iterator (obj): An iterator to load the data in training steps. + test_iterator (obj): An iterator to load the data in testing steps. graph (obj): An optional graph. seed (int): Random seed. """ @@ -36,8 +36,7 @@ def __init__( Args: hparams (obj): A tf.contrib.training.HParams object, hold the entire set of hyperparameters. - iterator_creator_train (obj): An iterator to load the data in training steps. - iterator_creator_train (obj): An iterator to load the data in testing steps. + iterator_creator (obj): An iterator to load the data. graph (obj): An optional graph. seed (int): Random seed. """ diff --git a/reco_utils/tuning/parameter_sweep.py b/reco_utils/tuning/parameter_sweep.py index 45c21ece22..8942be9b42 100644 --- a/reco_utils/tuning/parameter_sweep.py +++ b/reco_utils/tuning/parameter_sweep.py @@ -9,19 +9,24 @@ def generate_param_grid(params): """Generator of parameter grids. Generate parameter lists from a parameter dictionary in the form of: + .. code-block:: python - { - "param1": [value1, value2], - "param2": [value1, value2] - } + + { + "param1": [value1, value2], + "param2": [value1, value2] + } + to: + .. 
code-block:: python - [ - {"param1": value1, "param2": value1}, - {"param1": value2, "param2": value1}, - {"param1": value1, "param2": value2}, - {"param1": value2, "param2": value2} - ] + + [ + {"param1": value1, "param2": value1}, + {"param1": value2, "param2": value1}, + {"param1": value1, "param2": value2}, + {"param1": value2, "param2": value2} + ] Args: param_dict (dict): dictionary of parameters and values (in a list). diff --git a/setup.py b/setup.py index b938271f92..7b3a37fefd 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ "lightgbm>=2.2.1,<3", "memory_profiler>=0.54.0,<1", "nltk>=3.4,<4", - "pydocumentdb>=2.3.3<3", # todo: replace with azure-cosmos + "pydocumentdb>=2.3.3<3", # TODO: replace with azure-cosmos "pymanopt>=0.2.5,<1", "seaborn>=0.8.1,<1", "transformers>=2.5.0,<5", @@ -44,11 +44,11 @@ "pyyaml>=5.4.1,<6", "requests>=2.0.0,<3", "cornac>=1.1.2,<2", - "scikit-surprise>=0.19.1,<=1.1.1" + "scikit-surprise>=0.19.1,<=1.1.1", + "backoff>=1.8.0", ] # shared dependencies - extras_require = { "examples": [ "azure.mgmt.cosmosdb>=0.8.0,<1", @@ -61,8 +61,8 @@ ], "gpu": [ "nvidia-ml-py3>=7.352.0", - "tensorflow-gpu>=1.15.0,<2", # compiled with cuda 10.0 - "torch==1.2.0", # last os-common version with cuda 10.0 support + "tensorflow-gpu>=1.15.0,<2", # compiled with CUDA 10.0 + "torch==1.2.0", # last os-common version with CUDA 10.0 support "fastai>=1.0.46,<2", ], "spark": [ @@ -94,7 +94,7 @@ url="https://github.com/microsoft/recommenders", project_urls={ "Documentation": "https://microsoft-recommenders.readthedocs.io/en/stable/", - "Wiki": "https://github.com/microsoft/recommenders/wiki" + "Wiki": "https://github.com/microsoft/recommenders/wiki", }, author="RecoDev Team at Microsoft", author_email="RecoDevTeam@service.microsoft.com", @@ -109,11 +109,11 @@ "Programming Language :: Python :: 3.6", "Operating System :: Microsoft :: Windows", "Operating System :: POSIX :: Linux", - "Operating System :: MacOS" + "Operating System :: 
MacOS", ], extras_require=extras_require, keywords="recommendations recommenders recommender system engine " - "machine learning python spark gpu", + "machine learning python spark gpu", install_requires=install_requires, package_dir={"reco_utils": "reco_utils"}, packages=find_packages(where=".", exclude=["tests", "tools", "examples"]), diff --git a/tests/ci/azure_pipeline_test/dsvm_linux_template.yml b/tests/ci/azure_pipeline_test/dsvm_linux_template.yml index 1f0e918e37..d72446c64e 100644 --- a/tests/ci/azure_pipeline_test/dsvm_linux_template.yml +++ b/tests/ci/azure_pipeline_test/dsvm_linux_template.yml @@ -60,11 +60,10 @@ jobs: echo " --- BUILDING PACKAGE ---" rm -rf dist - pip install setuptools wheel twine || exit -1 python setup.py sdist bdist_wheel --plat-name=$PLATFORM || exit -1 - ls -lha dist + echo " --- INSTALLING WHEEL ---" - pip install --user dist/ms_recommenders-$RELEASE_VERSION-py3-none-$PLATFORM.whl${{ parameters.pip_opts }} --force-reinstall || exit -1 + pip install dist/ms_recommenders-$RELEASE_VERSION-py3-none-$PLATFORM.whl${{ parameters.pip_opts }} || exit -1 else echo "Installing latest code" pip install .${{ parameters.pip_opts }} || exit -1 @@ -80,6 +79,7 @@ jobs: - ${{ each test in parameters.test_types }}: - script: | eval "$(conda shell.bash hook)" + conda activate ${{ parameters.conda_env }} pip install pytest>=3.6.4 || exit -1 @@ -100,6 +100,11 @@ jobs: export TEST_MARKER="${{ test }} and " fi + # Remove reco_utils folder to make sure it is using the package to run the tests + # We leave the yaml files from DeepRec that are used in the tests + find reco_utils ! 
-name "*.yaml" -delete 2>/dev/null + du -a reco_utils + # run tests pytest tests/${{ test }} \ ${{ parameters.pytest_params }} \ diff --git a/tests/ci/azure_pipeline_test/release_pipeline.yml b/tests/ci/azure_pipeline_test/release_pipeline.yml index db1ca98fa5..4d32b3a635 100644 --- a/tests/ci/azure_pipeline_test/release_pipeline.yml +++ b/tests/ci/azure_pipeline_test/release_pipeline.yml @@ -6,10 +6,10 @@ pr: none # A new tag will trigger the build trigger: -- pipeline_release - # tags: - # include: - # - * + tags: + include: + - 0.* + - 1.* variables: - group: LinuxAgentPool @@ -26,102 +26,102 @@ jobs: pip_opts: "" pytest_markers: "not notebooks and not spark and not gpu" install: "release" - -- template: dsvm_linux_template.yml - parameters: - test_types: - - unit - task_name: "Test - Unit Notebook Linux CPU" - conda_env: "release_unit_notebook_linux_cpu" - conda_opts: "python=3.6" - pip_opts: "[examples]" - pytest_markers: "notebooks and not spark and not gpu" - install: "release" - -- template: dsvm_linux_template.yml - parameters: - test_types: - - unit - task_name: "Test - Unit Linux GPU" - conda_env: "release_unit_linux_gpu" - conda_opts: "python=3.6 cudatoolkit=10.0 \"cudnn>=7.6\"" - pip_opts: "[gpu] -f https://download.pytorch.org/whl/cu100/torch_stable.html" - pytest_markers: "not notebooks and not spark and gpu" - install: "release" - -- template: dsvm_linux_template.yml - parameters: - test_types: - - unit - task_name: "Test - Unit Notebook Linux GPU" - conda_env: "release_unit_notebook_linux_gpu" - conda_opts: "python=3.6 cudatoolkit=10.0 \"cudnn>=7.6\"" - pip_opts: "[gpu,examples] -f https://download.pytorch.org/whl/cu100/torch_stable.html" - pytest_markers: "notebooks and not spark and gpu" - install: "release" - -- template: dsvm_linux_template.yml - parameters: - test_types: - - unit - task_name: "Test - Unit Linux Spark" - conda_env: "release_unit_linux_spark" - conda_opts: "python=3.6" - pip_opts: "[spark]" - pytest_markers: "not notebooks and 
spark and not gpu" - install: "release" - -- template: dsvm_linux_template.yml - parameters: - test_types: - - unit - task_name: "Test - Unit Notebook Linux Spark" - conda_env: "release_unit_notebook_linux_spark" - conda_opts: "python=3.6" - pip_opts: "[spark,examples]" - pytest_markers: "notebooks and spark and not gpu" - install: "release" - -# ====== Nightly tests ====== -- template: dsvm_linux_template.yml - parameters: - test_types: - - smoke - - integration - task_name: "Test - Nightly Linux CPU" - timeout: 180 - conda_env: "release_nightly_linux_cpu" - conda_opts: "python=3.6" - pip_opts: "[examples]" - pytest_markers: "not spark and not gpu" - install: "release" - -- template: dsvm_linux_template.yml - parameters: - test_types: - - smoke - - integration - task_name: "Test - Nightly Linux GPU" - timeout: 240 - conda_env: "release_nightly_linux_gpu" - conda_opts: "python=3.6 cudatoolkit=10.0 \"cudnn>=7.6\"" - pip_opts: "[gpu,examples] -f https://download.pytorch.org/whl/cu100/torch_stable.html" - pytest_markers: "not spark and gpu" - install: "release" - -- template: dsvm_linux_template.yml - parameters: - test_types: - - smoke - - integration - task_name: "Test - Nightly Linux Spark" - timeout: 180 - conda_env: "release_nightly_linux_spark" - conda_opts: "python=3.6" - pip_opts: "[spark,examples]" - pytest_markers: "spark and not gpu" - install: "release" - package: "publish" # We want to publish to the package limbo the latest wheel + package: "publish" + +# - template: dsvm_linux_template.yml +# parameters: +# test_types: +# - unit +# task_name: "Test - Unit Notebook Linux CPU" +# conda_env: "release_unit_notebook_linux_cpu" +# conda_opts: "python=3.6" +# pip_opts: "[examples]" +# pytest_markers: "notebooks and not spark and not gpu" +# install: "release" + +# - template: dsvm_linux_template.yml +# parameters: +# test_types: +# - unit +# task_name: "Test - Unit Linux GPU" +# conda_env: "release_unit_linux_gpu" +# conda_opts: "python=3.6 cudatoolkit=10.0 
\"cudnn>=7.6\"" +# pip_opts: "[gpu] -f https://download.pytorch.org/whl/cu100/torch_stable.html" +# pytest_markers: "not notebooks and not spark and gpu" +# install: "release" + +# - template: dsvm_linux_template.yml +# parameters: +# test_types: +# - unit +# task_name: "Test - Unit Notebook Linux GPU" +# conda_env: "release_unit_notebook_linux_gpu" +# conda_opts: "python=3.6 cudatoolkit=10.0 \"cudnn>=7.6\"" +# pip_opts: "[gpu,examples] -f https://download.pytorch.org/whl/cu100/torch_stable.html" +# pytest_markers: "notebooks and not spark and gpu" +# install: "release" + +# - template: dsvm_linux_template.yml +# parameters: +# test_types: +# - unit +# task_name: "Test - Unit Linux Spark" +# conda_env: "release_unit_linux_spark" +# conda_opts: "python=3.6" +# pip_opts: "[spark]" +# pytest_markers: "not notebooks and spark and not gpu" +# install: "release" + +# - template: dsvm_linux_template.yml +# parameters: +# test_types: +# - unit +# task_name: "Test - Unit Notebook Linux Spark" +# conda_env: "release_unit_notebook_linux_spark" +# conda_opts: "python=3.6" +# pip_opts: "[spark,examples]" +# pytest_markers: "notebooks and spark and not gpu" +# install: "release" + +# # ====== Nightly tests ====== +# - template: dsvm_linux_template.yml +# parameters: +# test_types: +# - smoke +# - integration +# task_name: "Test - Nightly Linux CPU" +# timeout: 180 +# conda_env: "release_nightly_linux_cpu" +# conda_opts: "python=3.6" +# pip_opts: "[examples]" +# pytest_markers: "not spark and not gpu" +# install: "release" + +# - template: dsvm_linux_template.yml +# parameters: +# test_types: +# - smoke +# - integration +# task_name: "Test - Nightly Linux GPU" +# timeout: 240 +# conda_env: "release_nightly_linux_gpu" +# conda_opts: "python=3.6 cudatoolkit=10.0 \"cudnn>=7.6\"" +# pip_opts: "[gpu,examples] -f https://download.pytorch.org/whl/cu100/torch_stable.html" +# pytest_markers: "not spark and gpu" +# install: "release" + +# - template: dsvm_linux_template.yml +# parameters: 
+# test_types: +# - smoke +# - integration +# task_name: "Test - Nightly Linux Spark" +# timeout: 180 +# conda_env: "release_nightly_linux_spark" +# conda_opts: "python=3.6" +# pip_opts: "[spark,examples]" +# pytest_markers: "spark and not gpu" +# install: "release" # ====== Publish release ====== - job: Package @@ -129,20 +129,20 @@ jobs: name: $(Agent_Pool) dependsOn: - release_unit_linux_cpu - - release_unit_notebook_linux_cpu - - release_unit_linux_gpu - - release_unit_notebook_linux_gpu - - release_unit_linux_spark - - release_unit_notebook_linux_spark - - release_nightly_linux_cpu - - release_nightly_linux_gpu - - release_nightly_linux_spark + # - release_unit_notebook_linux_cpu + # - release_unit_linux_gpu + # - release_unit_notebook_linux_gpu + # - release_unit_linux_spark + # - release_unit_notebook_linux_spark + # - release_nightly_linux_cpu + # - release_nightly_linux_gpu + # - release_nightly_linux_spark condition: succeeded() steps: # Create archives with complete source code included - task: ArchiveFiles@2 displayName: Create zip archive of reco_utils - condition: succeeded() #and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) + condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) inputs: rootFolderOrFile: $(Build.SourcesDirectory)/reco_utils includeRootFolder: false @@ -152,7 +152,7 @@ jobs: - task: ArchiveFiles@2 displayName: Create tar.gz archive of reco_utils - condition: succeeded() #and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) + condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) inputs: rootFolderOrFile: $(Build.SourcesDirectory)/reco_utils includeRootFolder: false @@ -163,7 +163,7 @@ jobs: - task: ArchiveFiles@2 displayName: Create zip archive of examples - condition: succeeded() #and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) + condition: and(succeeded(), 
startsWith(variables['Build.SourceBranch'], 'refs/tags/')) inputs: rootFolderOrFile: $(Build.SourcesDirectory)/examples includeRootFolder: false @@ -173,7 +173,7 @@ jobs: - task: ArchiveFiles@2 displayName: Create tar.gz archive of examples - condition: succeeded() #and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) + condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) inputs: rootFolderOrFile: $(Build.SourcesDirectory)/examples includeRootFolder: false @@ -184,14 +184,14 @@ jobs: - task: DownloadPipelineArtifact@2 # Documentation: https://docs.microsoft.com/en-us/azure/devops/pipelines/artifacts/pipeline-artifacts?view=azure-devops&tabs=yaml-task displayName: 'Download Artifacts from Shared Storage' - condition: succeeded() #and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) + condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) inputs: artifact: PackageAssets path: $(Build.SourcesDirectory)/binaries - task: GitHubRelease@0 # Documentation: https://docs.microsoft.com/en-us/azure/devops/pipelines/tasks/utility/github-release?view=azure-devops displayName: 'Create GitHub Draft Release' - condition: succeeded() #and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) + condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/tags/')) inputs: gitHubConnection: recommenders_release repositoryName: '$(Build.Repository.Name)' diff --git a/tests/integration/examples/test_notebooks_gpu.py b/tests/integration/examples/test_notebooks_gpu.py index c23a8eaa64..8467fcc06d 100644 --- a/tests/integration/examples/test_notebooks_gpu.py +++ b/tests/integration/examples/test_notebooks_gpu.py @@ -206,24 +206,6 @@ def test_xdeepfm_integration( ) -# TODO: remove tf dependency in this notebook and drop gpu marker -@pytest.mark.gpu -@pytest.mark.integration -def test_xlearn_fm_integration(notebooks, output_notebook, kernel_name): 
- notebook_path = notebooks["xlearn_fm_deep_dive"] - pm.execute_notebook( - notebook_path, - output_notebook, - kernel_name=kernel_name, - parameters=dict(LEARNING_RATE=0.2, EPOCH=10), - ) - results = sb.read_notebook(output_notebook).scraps.dataframe.set_index("name")[ - "data" - ] - - assert results["auc_score"] == pytest.approx(0.75, rel=TOL, abs=ABS_TOL) - - @pytest.mark.gpu @pytest.mark.integration @pytest.mark.parametrize( diff --git a/tests/integration/examples/test_notebooks_python.py b/tests/integration/examples/test_notebooks_python.py index 4802e831e5..30827dd3e7 100644 --- a/tests/integration/examples/test_notebooks_python.py +++ b/tests/integration/examples/test_notebooks_python.py @@ -3,6 +3,7 @@ import sys import pytest + try: import papermill as pm import scrapbook as sb @@ -166,7 +167,7 @@ def test_vw_deep_dive_integration( assert results[key] == pytest.approx(value, rel=TOL, abs=ABS_TOL) -#@pytest.mark.skipif(sys.platform == "win32", reason="nni not installable on windows") +# @pytest.mark.skipif(sys.platform == "win32", reason="nni not installable on windows") @pytest.mark.integration @pytest.mark.skip(reason="Tests removed due to installation incompatibilities") def test_nni_tuning_svd(notebooks, output_notebook, kernel_name, tmp): @@ -246,3 +247,19 @@ def test_geoimc_integration(notebooks, output_notebook, kernel_name, expected_va for key, value in expected_values.items(): assert results[key] == pytest.approx(value, rel=TOL, abs=ABS_TOL) + + +@pytest.mark.integration +def test_xlearn_fm_integration(notebooks, output_notebook, kernel_name): + notebook_path = notebooks["xlearn_fm_deep_dive"] + pm.execute_notebook( + notebook_path, + output_notebook, + kernel_name=kernel_name, + parameters=dict(LEARNING_RATE=0.2, EPOCH=10), + ) + results = sb.read_notebook(output_notebook).scraps.dataframe.set_index("name")[ + "data" + ] + + assert results["auc_score"] == pytest.approx(0.75, rel=TOL, abs=ABS_TOL) diff --git 
a/tests/unit/reco_utils/dataset/test_dataset.py b/tests/unit/reco_utils/dataset/test_dataset.py index 41dfe3a338..d6c92037d2 100644 --- a/tests/unit/reco_utils/dataset/test_dataset.py +++ b/tests/unit/reco_utils/dataset/test_dataset.py @@ -2,21 +2,65 @@ # Licensed under the MIT License. import os -import sys import pytest from tempfile import TemporaryDirectory +import logging from reco_utils.dataset.download_utils import maybe_download, download_path -def test_maybe_download(): +@pytest.fixture +def files_fixtures(): file_url = "https://raw.githubusercontent.com/Microsoft/Recommenders/main/LICENSE" filepath = "license.txt" - assert not os.path.exists(filepath) - filepath = maybe_download(file_url, "license.txt", expected_bytes=1162) - assert os.path.exists(filepath) - os.remove(filepath) + return file_url, filepath + + +def test_maybe_download(files_fixtures): + file_url, filepath = files_fixtures + if os.path.exists(filepath): + os.remove(filepath) + + downloaded_filepath = maybe_download(file_url, "license.txt", expected_bytes=1162) + assert os.path.exists(downloaded_filepath) + assert downloaded_filepath.split("/")[-1] == "license.txt" + + +def test_maybe_download_wrong_bytes(caplog, files_fixtures): + caplog.clear() + caplog.set_level(logging.INFO) + + file_url, filepath = files_fixtures + if os.path.exists(filepath): + os.remove(filepath) + with pytest.raises(IOError): filepath = maybe_download(file_url, "license.txt", expected_bytes=0) + assert "Failed to verify license.txt" in caplog.text + + +def test_maybe_download_maybe(caplog, files_fixtures): + caplog.clear() + caplog.set_level(logging.INFO) + + file_url, filepath = files_fixtures + if os.path.exists(filepath): + os.remove(filepath) + + downloaded_filepath = maybe_download(file_url, "license.txt") + assert os.path.exists(downloaded_filepath) + maybe_download(file_url, "license.txt") + assert "File ./license.txt already downloaded" in caplog.text + + +# def test_maybe_download_retry(caplog): +# TODO: 
consider https://github.com/rholder/retrying/blob/master/retrying.py +# caplog.clear() +# caplog.set_level(logging.INFO) + +# maybe_download( +# "https://raw.githubusercontent.com/Microsoft/Recommenders/main/non_existing_file.zip" +# ) +# assert "Backing off" in caplog.text def test_download_path(): diff --git a/tests/unit/reco_utils/recommender/test_deeprec_model.py b/tests/unit/reco_utils/recommender/test_deeprec_model.py index 81971a2823..6786ec79ad 100644 --- a/tests/unit/reco_utils/recommender/test_deeprec_model.py +++ b/tests/unit/reco_utils/recommender/test_deeprec_model.py @@ -54,7 +54,8 @@ def test_xdeepfm_component_definition(deeprec_resource_path): @pytest.mark.gpu -def test_dkn_component_definition(deeprec_resource_path): +@pytest.fixture(scope="module") +def dkn_files(deeprec_resource_path): data_path = os.path.join(deeprec_resource_path, "dkn") yaml_file = os.path.join(data_path, "dkn.yaml") news_feature_file = os.path.join(data_path, r"doc_feature.txt") @@ -68,7 +69,31 @@ def test_dkn_component_definition(deeprec_resource_path): data_path, "mind-demo.zip", ) + return ( + data_path, + yaml_file, + news_feature_file, + user_history_file, + wordEmb_file, + entityEmb_file, + contextEmb_file, + ) + +@pytest.mark.gpu +def test_dkn_component_definition(dkn_files): + # Load params from fixture + ( + _, + yaml_file, + news_feature_file, + user_history_file, + wordEmb_file, + entityEmb_file, + contextEmb_file, + ) = dkn_files + + # Test DKN model hparams = prepare_hparams( yaml_file, news_feature_file=news_feature_file, @@ -80,13 +105,27 @@ def test_dkn_component_definition(deeprec_resource_path): learning_rate=0.0001, ) assert hparams is not None - model = DKN(hparams, DKNTextIterator) + model = DKN(hparams, DKNTextIterator) assert model.logit is not None assert model.update is not None assert model.iterator is not None - ### test DKN's item2item version + +@pytest.mark.gpu +def test_dkn_item2item_component_definition(dkn_files): + # Load params from 
fixture + ( + data_path, + yaml_file, + news_feature_file, + _, + wordEmb_file, + entityEmb_file, + contextEmb_file, + ) = dkn_files + + # Test DKN's item2item version hparams = prepare_hparams( yaml_file, news_feature_file=news_feature_file, @@ -101,86 +140,150 @@ def test_dkn_component_definition(deeprec_resource_path): use_entity=True, use_context=True, ) - hparams.neg_num = 9 assert hparams is not None - model_item2item = DKNItem2Item(hparams, DKNItem2itemTextIterator) + hparams.neg_num = 9 + model_item2item = DKNItem2Item(hparams, DKNItem2itemTextIterator) assert model_item2item.pred_logits is not None assert model_item2item.update is not None assert model_item2item.iterator is not None @pytest.mark.gpu -def test_slirec_component_definition(deeprec_resource_path, deeprec_config_path): +@pytest.fixture(scope="module") +def sequential_files(deeprec_resource_path): data_path = os.path.join(deeprec_resource_path, "slirec") - yaml_file = os.path.join(deeprec_config_path, "sli_rec.yaml") - yaml_file_nextitnet = os.path.join(deeprec_config_path, "nextitnet.yaml") - yaml_file_sum = os.path.join(deeprec_config_path, "sum.yaml") train_file = os.path.join(data_path, r"train_data") + valid_file = os.path.join(data_path, r"valid_data") + test_file = os.path.join(data_path, r"test_data") + user_vocab = os.path.join(data_path, r"user_vocab.pkl") + item_vocab = os.path.join(data_path, r"item_vocab.pkl") + cate_vocab = os.path.join(data_path, r"category_vocab.pkl") - if not os.path.exists(train_file): - train_file = os.path.join(data_path, r"train_data") - valid_file = os.path.join(data_path, r"valid_data") - test_file = os.path.join(data_path, r"test_data") - user_vocab = os.path.join(data_path, r"user_vocab.pkl") - item_vocab = os.path.join(data_path, r"item_vocab.pkl") - cate_vocab = os.path.join(data_path, r"category_vocab.pkl") - - reviews_name = "reviews_Movies_and_TV_5.json" - meta_name = "meta_Movies_and_TV.json" - reviews_file = os.path.join(data_path, reviews_name) - 
meta_file = os.path.join(data_path, meta_name) - valid_num_ngs = ( - 4 # number of negative instances with a positive instance for validation - ) - test_num_ngs = ( - 9 # number of negative instances with a positive instance for testing - ) - sample_rate = ( - 0.01 # sample a small item set for training and testing here for example - ) + reviews_name = "reviews_Movies_and_TV_5.json" + meta_name = "meta_Movies_and_TV.json" + reviews_file = os.path.join(data_path, reviews_name) + meta_file = os.path.join(data_path, meta_name) + valid_num_ngs = ( + 4 # number of negative instances with a positive instance for validation + ) + test_num_ngs = ( + 9 # number of negative instances with a positive instance for testing + ) + sample_rate = ( + 0.01 # sample a small item set for training and testing here for example + ) - input_files = [ - reviews_file, - meta_file, - train_file, - valid_file, - test_file, - user_vocab, - item_vocab, - cate_vocab, - ] - download_and_extract(reviews_name, reviews_file) - download_and_extract(meta_name, meta_file) - data_preprocessing( - *input_files, - sample_rate=sample_rate, - valid_num_ngs=valid_num_ngs, - test_num_ngs=test_num_ngs - ) + input_files = [ + reviews_file, + meta_file, + train_file, + valid_file, + test_file, + user_vocab, + item_vocab, + cate_vocab, + ] + download_and_extract(reviews_name, reviews_file) + download_and_extract(meta_name, meta_file) + data_preprocessing( + *input_files, + sample_rate=sample_rate, + valid_num_ngs=valid_num_ngs, + test_num_ngs=test_num_ngs + ) + + return ( + data_path, + user_vocab, + item_vocab, + cate_vocab, + ) + + +@pytest.mark.gpu +def test_slirec_component_definition(sequential_files, deeprec_config_path): + yaml_file = os.path.join(deeprec_config_path, "sli_rec.yaml") + data_path, user_vocab, item_vocab, cate_vocab = sequential_files hparams = prepare_hparams( - yaml_file, train_num_ngs=4 - ) # confirm the train_num_ngs when initializing a SLi_Rec model. 
- model = SLI_RECModel(hparams, SequentialIterator) - # nextitnet model - hparams_nextitnet = prepare_hparams(yaml_file_nextitnet, train_num_ngs=4) - model_nextitnet = NextItNetModel(hparams_nextitnet, NextItNetIterator) - # sum model - hparams_sum = prepare_hparams( - yaml_file_sum, train_num_ngs=4 - ) # confirm the train_num_ngs when initializing a SLi_Rec model. - model_sum = SUMModel(hparams_sum, SequentialIterator) + yaml_file, + train_num_ngs=4, + embed_l2=0.0, + layer_l2=0.0, + learning_rate=0.001, + epochs=1, + MODEL_DIR=os.path.join(data_path, "model"), + SUMMARIES_DIR=os.path.join(data_path, "summary"), + user_vocab=user_vocab, + item_vocab=item_vocab, + cate_vocab=cate_vocab, + need_sample=True, + ) + assert hparams is not None + model = SLI_RECModel(hparams, SequentialIterator) assert model.logit is not None assert model.update is not None assert model.iterator is not None + +@pytest.mark.gpu +def test_nextitnet_component_definition(sequential_files, deeprec_config_path): + yaml_file_nextitnet = os.path.join(deeprec_config_path, "nextitnet.yaml") + data_path, user_vocab, item_vocab, cate_vocab = sequential_files + + # NextItNet model + hparams_nextitnet = prepare_hparams( + yaml_file_nextitnet, + train_num_ngs=4, + embed_l2=0.0, + layer_l2=0.0, + learning_rate=0.001, + epochs=1, + MODEL_DIR=os.path.join(data_path, "model"), + SUMMARIES_DIR=os.path.join(data_path, "summary"), + user_vocab=user_vocab, + item_vocab=item_vocab, + cate_vocab=cate_vocab, + need_sample=True, + ) + assert hparams_nextitnet is not None + + model_nextitnet = NextItNetModel(hparams_nextitnet, NextItNetIterator) assert model_nextitnet.logit is not None assert model_nextitnet.update is not None assert model_nextitnet.iterator is not None +@pytest.mark.gpu +def test_sum_component_definition(sequential_files, deeprec_config_path): + yaml_file_sum = os.path.join(deeprec_config_path, "sum.yaml") + data_path, user_vocab, item_vocab, cate_vocab = sequential_files + + # SUM model + 
hparams_sum = prepare_hparams( + yaml_file_sum, + train_num_ngs=4, + embed_l2=0.0, + layer_l2=0.0, + learning_rate=0.001, + epochs=1, + MODEL_DIR=os.path.join(data_path, "model"), + SUMMARIES_DIR=os.path.join(data_path, "summary"), + user_vocab=user_vocab, + item_vocab=item_vocab, + cate_vocab=cate_vocab, + need_sample=True, + ) + assert hparams_sum is not None + + model_sum = SUMModel(hparams_sum, SequentialIterator) + assert model_sum.logit is not None + assert model_sum.update is not None + assert model_sum.iterator is not None + + @pytest.mark.gpu def test_lightgcn_component_definition(deeprec_config_path): yaml_file = os.path.join(deeprec_config_path, "lightgcn.yaml") diff --git a/tests/unit/reco_utils/recommender/test_geoimc.py b/tests/unit/reco_utils/recommender/test_geoimc.py index 3c9177d982..717c29e185 100644 --- a/tests/unit/reco_utils/recommender/test_geoimc.py +++ b/tests/unit/reco_utils/recommender/test_geoimc.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd from scipy.sparse import csr_matrix -from pandas.util.testing import assert_frame_equal +from pandas.testing import assert_frame_equal from reco_utils.common.python_utils import binarize from reco_utils.recommender.geoimc.geoimc_data import DataPtr