Skip to content

Commit

Permalink
fix nb
Browse files Browse the repository at this point in the history
  • Loading branch information
marcocuturi committed Nov 14, 2022
1 parent f099201 commit 9ba0879
Showing 1 changed file with 11 additions and 43 deletions.
54 changes: 11 additions & 43 deletions docs/notebooks/point_clouds.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -216989,7 +216989,7 @@
},
{
"cell_type": "code",
"execution_count": 72,
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -217000,42 +217000,17 @@
" label = (\n",
" r\"$T_{x\\rightarrow y}(x)$\" if forward else r\"$T_{y\\rightarrow x}(y)$\"\n",
" )\n",
" plt.scatter(*x.T, s=200, edgecolors=\"k\", marker=\"o\", label=r\"$x$\")\n",
" plt.scatter(*y.T, s=200, edgecolors=\"k\", marker=\"X\", label=r\"$y$\")\n",
" plt.scatter(*z.T, s=150, edgecolors=\"k\", marker=marker_t, label=label)\n",
" plt.plot(\n",
" jnp.vstack((x[:, 0] if forward else y[:, 0], z[:, 0])),\n",
" jnp.vstack((x[:, 1] if forward else y[:, 1], z[:, 1])),\n",
" c=\"k\",\n",
" )\n",
" plt.scatter(*x.T, s=200, edgecolors=\"k\", marker=\"o\", label=r\"$x$\")\n",
" plt.scatter(*y.T, s=200, edgecolors=\"k\", marker=\"X\", label=r\"$y$\")\n",
" plt.scatter(*z.T, s=150, edgecolors=\"k\", marker=marker_t, label=label)\n",
" plt.legend(fontsize=22)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(13, 2)"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"forward = True\n",
"z = dual_potentials.transport(x)\n",
"p = jnp.stack((x[:, 0] if forward else y[:, 0], z[:, 0]), axis=-1)\n",
"q = jnp.stack((x[:, 1] if forward else y[:, 1], z[:, 1]), axis=-1)\n",
"p.shape\n",
"q.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -217045,7 +217020,7 @@
},
{
"cell_type": "code",
"execution_count": 74,
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -217056,7 +217031,7 @@
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 83,
"metadata": {},
"outputs": [
{
Expand All @@ -217078,7 +217053,7 @@
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 84,
"metadata": {},
"outputs": [
{
Expand All @@ -217102,12 +217077,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also experiment with that method using a more exotic cost function, here a squared `p=1.1` Norm."
"We can also experiment with that method using a more exotic cost function that penalizes \"diagonal\" moves (here a squared `p=1.1` Norm.), one gets a different type of map."
]
},
{
"cell_type": "code",
"execution_count": 78,
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -217118,7 +217093,7 @@
},
{
"cell_type": "code",
"execution_count": 79,
"execution_count": 89,
"metadata": {},
"outputs": [
{
Expand All @@ -217140,7 +217115,7 @@
},
{
"cell_type": "code",
"execution_count": 80,
"execution_count": 90,
"metadata": {},
"outputs": [
{
Expand All @@ -217159,13 +217134,6 @@
"source": [
"plot_map(x, y, dual_potentials.transport(y, forward=False), forward=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit 9ba0879

Please sign in to comment.