From 9ce642f5de9f83d4437caa170511c70bdce2ac96 Mon Sep 17 00:00:00 2001 From: "pasquale c." <343guiltyspark@outlook.it> Date: Wed, 23 Oct 2024 00:14:37 +0200 Subject: [PATCH] change verbose with verbosity --- Project.toml | 10 - .../tutorials/logit/execute-results/md.json | 2 +- .../src/tutorials/mlp/execute-results/md.json | 2 +- .../tutorials/multi/execute-results/md.json | 2 +- .../regression/execute-results/md.json | 2 +- dev/notebooks/KFAC/sb/Multi.jl.ipynb | 2 +- .../Multi-class_classification.jl.ipynb | 2 +- dev/notebooks/batching/Trials-01-Zygote.ipynb | 6 +- .../batching/Trials-03-Jacobians.ipynb | 2 +- dev/notebooks/batching/regression.ipynb | 2 +- dev/notebooks/ggn/GGN-Julia.ipynb | 2 +- .../multi-class/Multi-Class-Julia-FGD.ipynb | 2 +- .../multi-class/Multi-Class-Julia-SGD.ipynb | 2 +- .../network_subsets/subnetworks_laplace.ipynb | 8 +- docs/src/tutorials/logit.md | 2 +- docs/src/tutorials/logit.qmd | 2 +- docs/src/tutorials/mlp.md | 2 +- docs/src/tutorials/mlp.qmd | 2 +- docs/src/tutorials/multi.md | 2 +- docs/src/tutorials/multi.qmd | 2 +- docs/src/tutorials/regression.md | 2 +- docs/src/tutorials/regression.qmd | 2 +- src/LaplaceRedux.jl | 3 - src/baselaplace/optimize_prior.jl | 4 +- src/mlj_flux.jl | 494 ----------------- test/Manifest.toml | 500 ++++++++++-------- test/Project.toml | 2 +- test/laplace.jl | 8 +- test/mlj_flux_interfacing.jl | 205 ------- test/runtests.jl | 3 - 30 files changed, 305 insertions(+), 976 deletions(-) delete mode 100644 src/mlj_flux.jl delete mode 100644 test/mlj_flux_interfacing.jl diff --git a/Project.toml b/Project.toml index 451430e3..a06a996b 100644 --- a/Project.toml +++ b/Project.toml @@ -6,16 +6,11 @@ version = "1.1.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" -ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" -MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" @@ -26,16 +21,11 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Aqua = "0.8" ChainRulesCore = "1.23.0" Compat = "4.7.0" -ComputationalResources = "0.3.2" Distributions = "0.25.109" Flux = "0.12, 0.13, 0.14" LinearAlgebra = "1.7, 1.10" -MLJBase = "1" -MLJFlux = "0.5" -MLJModelInterface = "1.8.0" MLUtils = "0.4" Optimisers = "0.2, 0.3" -ProgressMeter = "1.7.2" Random = "1.9, 1.10" Statistics = "1" Tables = "1.10.1" diff --git a/_freeze/docs/src/tutorials/logit/execute-results/md.json b/_freeze/docs/src/tutorials/logit/execute-results/md.json index 46d49ef3..e4d05f57 100644 --- a/_freeze/docs/src/tutorials/logit/execute-results/md.json +++ b/_freeze/docs/src/tutorials/logit/execute-results/md.json @@ -2,7 +2,7 @@ "hash": "64cc61b7b60f8aef12841a8bd09bc8bb", "result": { "engine": "jupyter", - "markdown": "```@meta\nCurrentModule = LaplaceRedux\n```\n\n# Bayesian Logistic Regression\n\n## Libraries\n\n::: {.cell execution_count=1}\n``` {.julia .cell-code}\nusing Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra\ntheme(:lime)\n```\n:::\n\n\n## Data\n\nWe will use synthetic data with linearly separable samples:\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\n# set seed\nseed= 1234\nRandom.seed!(seed)\n# Number of points to generate.\nxs, ys = LaplaceRedux.Data.toy_data_linear(100; seed=seed)\nX = hcat(xs...) # bring into tabular format\n```\n:::\n\n\nsplit in a training and test set\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\n# Shuffle the data\nn = length(ys)\nindices = randperm(n)\n\n# Define the split ratio\nsplit_ratio = 0.8\nsplit_index = Int(floor(split_ratio * n))\n\n# Split the data into training and test sets\ntrain_indices = indices[1:split_index]\ntest_indices = indices[split_index+1:end]\n\nxs_train = xs[train_indices]\nxs_test = xs[test_indices]\nys_train = ys[train_indices]\nys_test = ys[test_indices]\n# bring into tabular format\nX_train = hcat(xs_train...) \nX_test = hcat(xs_test...) \n\ndata = zip(xs_train,ys_train)\n```\n:::\n\n\n## Model\n\nLogistic regression with weight decay can be implemented in Flux.jl as a single dense (linear) layer with binary logit crossentropy loss:\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nnn = Chain(Dense(2,1))\nλ = 0.5\nsqnorm(x) = sum(abs2, x)\nweight_regularization(λ=λ) = 1/2 * λ^2 * sum(sqnorm, Flux.params(nn))\nloss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y) + weight_regularization()\n```\n:::\n\n\nThe code below simply trains the model. After about 50 training epochs training loss stagnates.\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\nusing Flux.Optimise: update!, Adam\nopt = Adam()\nepochs = 50\navg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(data)\n end\nend\n```\n:::\n\n\n## Laplace approximation\n\nLaplace approximation for the posterior predictive can be implemented as follows:\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\nla = Laplace(nn; likelihood=:classification, λ=λ, subset_of_weights=:last_layer)\nfit!(la, data)\nla_untuned = deepcopy(la) # saving for plotting\noptimize_prior!(la; verbose=true, n_steps=500)\n```\n:::\n\n\nThe plot below shows the resulting posterior predictive surface for the plugin estimator (left) and the Laplace approximation (right).\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\nzoom = 0\np_plugin = plot(la, X, ys; title=\"Plugin\", link_approx=:plugin, clim=(0,1))\np_untuned = plot(la_untuned, X, ys; title=\"LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))\", clim=(0,1), zoom=zoom)\np_laplace = plot(la, X, ys; title=\"LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))\", clim=(0,1), zoom=zoom)\nplot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))\n```\n\n::: {.cell-output .cell-output-display execution_count=8}\n![](logit_files/figure-commonmark/cell-8-output-1.svg){}\n:::\n:::\n\n\nNow we can test the level of calibration of the neural network.\nFirst we collect the predicted results over the test dataset\n\n::: {.cell execution_count=8}\n``` {.julia .cell-code}\n predicted_distributions= predict(la, X_test,ret_distr=true)\n```\n\n::: {.cell-output .cell-output-display execution_count=9}\n```\n1×20 Matrix{Distributions.Bernoulli{Float64}}:\n Distributions.Bernoulli{Float64}(p=0.13122) … Distributions.Bernoulli{Float64}(p=0.109559)\n```\n:::\n:::\n\n\nthen we plot the calibration plot\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\nCalibration_Plot(la,ys_test,vec(predicted_distributions);n_bins = 10)\n```\n\n::: {.cell-output .cell-output-display}\n![](logit_files/figure-commonmark/cell-10-output-1.svg){}\n:::\n:::\n\n\nas we can see from the plot, although extremely accurate, the neural network does not seem to be calibrated well. This is, however, an effect of the extreme accuracy reached by the neural network which causes the lack of predictions with high uncertainty (low certainty). We can see this by looking at the level of sharpness for the two classes which are extremely close to 1, indicating the high level of trust that the neural network has in the predictions.\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nsharpness_classification(ys_test,vec(predicted_distributions))\n```\n\n::: {.cell-output .cell-output-display execution_count=11}\n```\n(0.9131870336577175, 0.8865055827351365)\n```\n:::\n:::\n\n\n", + "markdown": "```@meta\nCurrentModule = LaplaceRedux\n```\n\n# Bayesian Logistic Regression\n\n## Libraries\n\n::: {.cell execution_count=1}\n``` {.julia .cell-code}\nusing Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra\ntheme(:lime)\n```\n:::\n\n\n## Data\n\nWe will use synthetic data with linearly separable samples:\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\n# set seed\nseed= 1234\nRandom.seed!(seed)\n# Number of points to generate.\nxs, ys = LaplaceRedux.Data.toy_data_linear(100; seed=seed)\nX = hcat(xs...) # bring into tabular format\n```\n:::\n\n\nsplit in a training and test set\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\n# Shuffle the data\nn = length(ys)\nindices = randperm(n)\n\n# Define the split ratio\nsplit_ratio = 0.8\nsplit_index = Int(floor(split_ratio * n))\n\n# Split the data into training and test sets\ntrain_indices = indices[1:split_index]\ntest_indices = indices[split_index+1:end]\n\nxs_train = xs[train_indices]\nxs_test = xs[test_indices]\nys_train = ys[train_indices]\nys_test = ys[test_indices]\n# bring into tabular format\nX_train = hcat(xs_train...) \nX_test = hcat(xs_test...) \n\ndata = zip(xs_train,ys_train)\n```\n:::\n\n\n## Model\n\nLogistic regression with weight decay can be implemented in Flux.jl as a single dense (linear) layer with binary logit crossentropy loss:\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nnn = Chain(Dense(2,1))\nλ = 0.5\nsqnorm(x) = sum(abs2, x)\nweight_regularization(λ=λ) = 1/2 * λ^2 * sum(sqnorm, Flux.params(nn))\nloss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y) + weight_regularization()\n```\n:::\n\n\nThe code below simply trains the model. After about 50 training epochs training loss stagnates.\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\nusing Flux.Optimise: update!, Adam\nopt = Adam()\nepochs = 50\navg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(data)\n end\nend\n```\n:::\n\n\n## Laplace approximation\n\nLaplace approximation for the posterior predictive can be implemented as follows:\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\nla = Laplace(nn; likelihood=:classification, λ=λ, subset_of_weights=:last_layer)\nfit!(la, data)\nla_untuned = deepcopy(la) # saving for plotting\noptimize_prior!(la; verbosity=1, n_steps=500)\n```\n:::\n\n\nThe plot below shows the resulting posterior predictive surface for the plugin estimator (left) and the Laplace approximation (right).\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\nzoom = 0\np_plugin = plot(la, X, ys; title=\"Plugin\", link_approx=:plugin, clim=(0,1))\np_untuned = plot(la_untuned, X, ys; title=\"LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))\", clim=(0,1), zoom=zoom)\np_laplace = plot(la, X, ys; title=\"LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))\", clim=(0,1), zoom=zoom)\nplot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))\n```\n\n::: {.cell-output .cell-output-display execution_count=8}\n![](logit_files/figure-commonmark/cell-8-output-1.svg){}\n:::\n:::\n\n\nNow we can test the level of calibration of the neural network.\nFirst we collect the predicted results over the test dataset\n\n::: {.cell execution_count=8}\n``` {.julia .cell-code}\n predicted_distributions= predict(la, X_test,ret_distr=true)\n```\n\n::: {.cell-output .cell-output-display execution_count=9}\n```\n1×20 Matrix{Distributions.Bernoulli{Float64}}:\n Distributions.Bernoulli{Float64}(p=0.13122) … Distributions.Bernoulli{Float64}(p=0.109559)\n```\n:::\n:::\n\n\nthen we plot the calibration plot\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\nCalibration_Plot(la,ys_test,vec(predicted_distributions);n_bins = 10)\n```\n\n::: {.cell-output .cell-output-display}\n![](logit_files/figure-commonmark/cell-10-output-1.svg){}\n:::\n:::\n\n\nas we can see from the plot, although extremely accurate, the neural network does not seem to be calibrated well. This is, however, an effect of the extreme accuracy reached by the neural network which causes the lack of predictions with high uncertainty (low certainty). We can see this by looking at the level of sharpness for the two classes which are extremely close to 1, indicating the high level of trust that the neural network has in the predictions.\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nsharpness_classification(ys_test,vec(predicted_distributions))\n```\n\n::: {.cell-output .cell-output-display execution_count=11}\n```\n(0.9131870336577175, 0.8865055827351365)\n```\n:::\n:::\n\n\n", "supporting": [ "logit_files" ], diff --git a/_freeze/docs/src/tutorials/mlp/execute-results/md.json b/_freeze/docs/src/tutorials/mlp/execute-results/md.json index 54d40aed..5a0197e2 100644 --- a/_freeze/docs/src/tutorials/mlp/execute-results/md.json +++ b/_freeze/docs/src/tutorials/mlp/execute-results/md.json @@ -2,7 +2,7 @@ "hash": "d070805f89710f2cc874e27ecc1a1c2b", "result": { "engine": "jupyter", - "markdown": "```@meta\nCurrentModule = LaplaceRedux\n```\n\n# Bayesian MLP\n\n## Libraries\n\n::: {.cell execution_count=1}\n``` {.julia .cell-code}\nusing Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra\ntheme(:lime)\n```\n:::\n\n\n## Data\n\nThis time we use a synthetic dataset containing samples that are not linearly separable:\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\n#set seed\nseed = 1234\nRandom.seed!(seed)\n# Number of points to generate.\nxs, ys = LaplaceRedux.Data.toy_data_non_linear(400; seed = seed)\n# Shuffle the data\nn = length(ys)\nindices = randperm(n)\n\n# Define the split ratio\nsplit_ratio = 0.8\nsplit_index = Int(floor(split_ratio * n))\n\n# Split the data into training and test sets\ntrain_indices = indices[1:split_index]\ntest_indices = indices[split_index+1:end]\n\nxs_train = xs[train_indices]\nxs_test = xs[test_indices]\nys_train = ys[train_indices]\nys_test = ys[test_indices]\n# bring into tabular format\nX_train = hcat(xs_train...) \nX_test = hcat(xs_test...) \n\ndata = zip(xs_train,ys_train)\n```\n:::\n\n\n## Model\nFor the classification task we build a neural network with weight decay composed of a single hidden layer.\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\nn_hidden = 10\nD = size(X_train,1)\nnn = Chain(\n Dense(D, n_hidden, σ),\n Dense(n_hidden, 1)\n) \nloss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y) \n```\n:::\n\n\nThe model is trained until training loss stagnates.\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nusing Flux.Optimise: update!, Adam\nopt = Adam(1e-3)\nepochs = 100\navg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(data)\n end\nend\n```\n:::\n\n\n## Laplace Approximation\n\nLaplace approximation can be implemented as follows:\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\nla = Laplace(nn; likelihood=:classification, subset_of_weights=:all)\nfit!(la, data)\nla_untuned = deepcopy(la) # saving for plotting\noptimize_prior!(la; verbose=true, n_steps=500)\n```\n:::\n\n\nThe plot below shows the resulting posterior predictive surface for the plugin estimator (left) and the Laplace approximation (right).\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\n# Plot the posterior distribution with a contour plot.\nzoom=0\np_plugin = plot(la, X_train, ys_train; title=\"Plugin\", link_approx=:plugin, clim=(0,1))\np_untuned = plot(la_untuned, X_train, ys_train; title=\"LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))\", clim=(0,1), zoom=zoom)\np_laplace = plot(la, X_train, ys_train; title=\"LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))\", clim=(0,1), zoom=zoom)\nplot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))\n```\n\n::: {.cell-output .cell-output-display execution_count=7}\n![](mlp_files/figure-commonmark/cell-7-output-1.svg){}\n:::\n:::\n\n\nZooming out we can note that the plugin estimator produces high-confidence estimates in regions scarce of any samples. The Laplace approximation is much more conservative about these regions.\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\nzoom=-50\np_plugin = plot(la, X_train, ys_train; title=\"Plugin\", link_approx=:plugin, clim=(0,1))\np_untuned = plot(la_untuned, X_train, ys_train; title=\"LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))\", clim=(0,1), zoom=zoom)\np_laplace = plot(la, X_train, ys_train; title=\"LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))\", clim=(0,1), zoom=zoom)\nplot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))\n```\n\n::: {.cell-output .cell-output-display execution_count=8}\n![](mlp_files/figure-commonmark/cell-8-output-1.svg){}\n:::\n:::\n\n\nWe plot now the calibration plot to assess the level of average calibration reached by the neural network.\n\n::: {.cell execution_count=8}\n``` {.julia .cell-code}\npredicted_distributions= predict(la, X_test,ret_distr=true)\nCalibration_Plot(la,ys_test,vec(predicted_distributions);n_bins = 10)\n```\n\n::: {.cell-output .cell-output-display}\n![](mlp_files/figure-commonmark/cell-9-output-1.svg){}\n:::\n:::\n\n\nand the sharpness score\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\nsharpness_classification(ys_test,vec(predicted_distributions))\n```\n\n::: {.cell-output .cell-output-display execution_count=10}\n```\n(0.9277189055456709, 0.9196132560599691)\n```\n:::\n:::\n\n\n", + "markdown": "```@meta\nCurrentModule = LaplaceRedux\n```\n\n# Bayesian MLP\n\n## Libraries\n\n::: {.cell execution_count=1}\n``` {.julia .cell-code}\nusing Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra\ntheme(:lime)\n```\n:::\n\n\n## Data\n\nThis time we use a synthetic dataset containing samples that are not linearly separable:\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\n#set seed\nseed = 1234\nRandom.seed!(seed)\n# Number of points to generate.\nxs, ys = LaplaceRedux.Data.toy_data_non_linear(400; seed = seed)\n# Shuffle the data\nn = length(ys)\nindices = randperm(n)\n\n# Define the split ratio\nsplit_ratio = 0.8\nsplit_index = Int(floor(split_ratio * n))\n\n# Split the data into training and test sets\ntrain_indices = indices[1:split_index]\ntest_indices = indices[split_index+1:end]\n\nxs_train = xs[train_indices]\nxs_test = xs[test_indices]\nys_train = ys[train_indices]\nys_test = ys[test_indices]\n# bring into tabular format\nX_train = hcat(xs_train...) \nX_test = hcat(xs_test...) \n\ndata = zip(xs_train,ys_train)\n```\n:::\n\n\n## Model\nFor the classification task we build a neural network with weight decay composed of a single hidden layer.\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\nn_hidden = 10\nD = size(X_train,1)\nnn = Chain(\n Dense(D, n_hidden, σ),\n Dense(n_hidden, 1)\n) \nloss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y) \n```\n:::\n\n\nThe model is trained until training loss stagnates.\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nusing Flux.Optimise: update!, Adam\nopt = Adam(1e-3)\nepochs = 100\navg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(data)\n end\nend\n```\n:::\n\n\n## Laplace Approximation\n\nLaplace approximation can be implemented as follows:\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\nla = Laplace(nn; likelihood=:classification, subset_of_weights=:all)\nfit!(la, data)\nla_untuned = deepcopy(la) # saving for plotting\noptimize_prior!(la; verbosity=1, n_steps=500)\n```\n:::\n\n\nThe plot below shows the resulting posterior predictive surface for the plugin estimator (left) and the Laplace approximation (right).\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\n# Plot the posterior distribution with a contour plot.\nzoom=0\np_plugin = plot(la, X_train, ys_train; title=\"Plugin\", link_approx=:plugin, clim=(0,1))\np_untuned = plot(la_untuned, X_train, ys_train; title=\"LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))\", clim=(0,1), zoom=zoom)\np_laplace = plot(la, X_train, ys_train; title=\"LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))\", clim=(0,1), zoom=zoom)\nplot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))\n```\n\n::: {.cell-output .cell-output-display execution_count=7}\n![](mlp_files/figure-commonmark/cell-7-output-1.svg){}\n:::\n:::\n\n\nZooming out we can note that the plugin estimator produces high-confidence estimates in regions scarce of any samples. The Laplace approximation is much more conservative about these regions.\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\nzoom=-50\np_plugin = plot(la, X_train, ys_train; title=\"Plugin\", link_approx=:plugin, clim=(0,1))\np_untuned = plot(la_untuned, X_train, ys_train; title=\"LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))\", clim=(0,1), zoom=zoom)\np_laplace = plot(la, X_train, ys_train; title=\"LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))\", clim=(0,1), zoom=zoom)\nplot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))\n```\n\n::: {.cell-output .cell-output-display execution_count=8}\n![](mlp_files/figure-commonmark/cell-8-output-1.svg){}\n:::\n:::\n\n\nWe plot now the calibration plot to assess the level of average calibration reached by the neural network.\n\n::: {.cell execution_count=8}\n``` {.julia .cell-code}\npredicted_distributions= predict(la, X_test,ret_distr=true)\nCalibration_Plot(la,ys_test,vec(predicted_distributions);n_bins = 10)\n```\n\n::: {.cell-output .cell-output-display}\n![](mlp_files/figure-commonmark/cell-9-output-1.svg){}\n:::\n:::\n\n\nand the sharpness score\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\nsharpness_classification(ys_test,vec(predicted_distributions))\n```\n\n::: {.cell-output .cell-output-display execution_count=10}\n```\n(0.9277189055456709, 0.9196132560599691)\n```\n:::\n:::\n\n\n", "supporting": [ "mlp_files\\figure-commonmark" ], diff --git a/_freeze/docs/src/tutorials/multi/execute-results/md.json b/_freeze/docs/src/tutorials/multi/execute-results/md.json index b5fcbc2c..757528bb 100644 --- a/_freeze/docs/src/tutorials/multi/execute-results/md.json +++ b/_freeze/docs/src/tutorials/multi/execute-results/md.json @@ -2,7 +2,7 @@ "hash": "ecc8efc3eca6cfcff72dc1b10acc73ff", "result": { "engine": "jupyter", - "markdown": "---\ntitle: Multi-class problem\n---\n\n\n\n## Libraries\n\n\n::: {.cell execution_count=1}\n``` {.julia .cell-code}\nusing Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux\ntheme(:lime)\n```\n:::\n\n\n## Data\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\nusing LaplaceRedux.Data\nseed = 1234\nx, y = Data.toy_data_multi(seed=seed)\nX = hcat(x...)\ny_onehot = Flux.onehotbatch(y, unique(y))\ny_onehot = Flux.unstack(y_onehot',1)\n```\n:::\n\n\nsplit in training and test datasets\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\n# Shuffle the data\nRandom.seed!(seed)\nn = length(y)\nindices = randperm(n)\n\n# Define the split ratio\nsplit_ratio = 0.8\nsplit_index = Int(floor(split_ratio * n))\n\n# Split the data into training and test sets\ntrain_indices = indices[1:split_index]\ntest_indices = indices[split_index+1:end]\n\nx_train = x[train_indices]\nx_test = x[test_indices]\ny_onehot_train = y_onehot[train_indices,:]\ny_onehot_test = y_onehot[test_indices,:]\n\ny_train = vec(y[train_indices,:])\ny_test = vec(y[test_indices,:])\n# bring into tabular format\nX_train = hcat(x_train...) \nX_test = hcat(x_test...) \n\ndata = zip(x_train,y_onehot_train)\n#data = zip(x,y_onehot)\n```\n:::\n\n\n## MLP\n\nWe set up a model\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nn_hidden = 3\nD = size(X,1)\nout_dim = length(unique(y))\nnn = Chain(\n Dense(D, n_hidden, σ),\n Dense(n_hidden, out_dim)\n) \nloss(x, y) = Flux.Losses.logitcrossentropy(nn(x), y)\n```\n:::\n\n\ntraining:\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\nusing Flux.Optimise: update!, Adam\nopt = Adam()\nepochs = 100\navg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(data)\n end\nend\n```\n:::\n\n\n## Laplace Approximation\n\nThe Laplace approximation can be implemented as follows:\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\nla = Laplace(nn; likelihood=:classification)\nfit!(la, data)\noptimize_prior!(la; verbose=true, n_steps=100)\n```\n:::\n\n\nwith either the probit approximation:\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\n_labels = sort(unique(y))\nplt_list = []\nfor target in _labels\n plt = plot(la, X_test, y_test; target=target, clim=(0,1))\n push!(plt_list, plt)\nend\nplot(plt_list...)\n```\n\n::: {.cell-output .cell-output-display execution_count=8}\n![](multi_files/figure-commonmark/cell-8-output-1.svg){}\n:::\n:::\n\n\n or the plugin approximation:\n\n::: {.cell execution_count=8}\n``` {.julia .cell-code}\n_labels = sort(unique(y))\nplt_list = []\nfor target in _labels\n plt = plot(la, X_test, y_test; target=target, clim=(0,1), link_approx=:plugin)\n push!(plt_list, plt)\nend\nplot(plt_list...)\n```\n\n::: {.cell-output .cell-output-display execution_count=9}\n![](multi_files/figure-commonmark/cell-9-output-1.svg){}\n:::\n:::\n\n\n## Calibration Plots\n\nIn the case of multiclass classification tasks, we cannot plot the calibration plots directly since they can only be used in the binary classification case. However, we can use them to plot the calibration of the predictions for 1 class against all the others. To do so, we first have to collect the predicted categorical distributions\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\npredicted_distributions= predict(la, X_test,ret_distr=true)\n```\n\n::: {.cell-output .cell-output-display execution_count=10}\n```\n1×20 Matrix{Distributions.Categorical{Float64, Vector{Float64}}}:\n Distributions.Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[0.0569184, 0.196066, 0.0296796, 0.717336]) … Distributions.Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[0.0569634, 0.195727, 0.0296449, 0.717665])\n```\n:::\n:::\n\n\nthen we transform the categorical distributions into Bernoulli distributions by taking only the probability of the class of interest, for example the third one. \n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nusing Distributions\nbernoulli_distributions = [Bernoulli(p.p[3]) for p in vec(predicted_distributions)]\n```\n\n::: {.cell-output .cell-output-display execution_count=11}\n```\n20-element Vector{Bernoulli{Float64}}:\n Bernoulli{Float64}(p=0.029679590887034743)\n Bernoulli{Float64}(p=0.6682373773598078)\n Bernoulli{Float64}(p=0.20912995228011141)\n Bernoulli{Float64}(p=0.20913322913224044)\n Bernoulli{Float64}(p=0.02971989045895732)\n Bernoulli{Float64}(p=0.668431087463204)\n Bernoulli{Float64}(p=0.03311710703617972)\n Bernoulli{Float64}(p=0.20912981531862682)\n Bernoulli{Float64}(p=0.11273726979027407)\n Bernoulli{Float64}(p=0.2490744632745955)\n Bernoulli{Float64}(p=0.029886357844211404)\n Bernoulli{Float64}(p=0.02965323602487074)\n Bernoulli{Float64}(p=0.1126799374664026)\n Bernoulli{Float64}(p=0.11278538625980777)\n Bernoulli{Float64}(p=0.6683139127616431)\n Bernoulli{Float64}(p=0.029644435143197145)\n Bernoulli{Float64}(p=0.11324691083703237)\n Bernoulli{Float64}(p=0.6681422555922787)\n Bernoulli{Float64}(p=0.668424345470233)\n Bernoulli{Float64}(p=0.029644891255330787)\n```\n:::\n:::\n\n\nNow we can use ```Calibration_Plot``` to see the level of calibration of the neural network\n\n::: {.cell execution_count=11}\n``` {.julia .cell-code}\nplt = Calibration_Plot(la,hcat(y_onehot_test...)[3,:],bernoulli_distributions;n_bins = 10);\n```\n\n::: {.cell-output .cell-output-display}\n![](multi_files/figure-commonmark/cell-12-output-1.svg){}\n:::\n:::\n\n\nThe plot is peaked around 0.7.\n\nA possible reason is that class 3 is relatively easy for the model to identify from the other classes, although it remains a bit underconfident in its predictions. \nAnother reason for the peak may be the lack of cases where the predicted probability is lower (e.g., around 0.5), which could indicate that the network has not encountered ambiguous or difficult-to-classify examples for such class. This once again might be because either class 3 has distinct features that the model can easily learn, leading to fewer uncertain predictions, or is a consequence of the limited dataset. \n\n We can measure how sharp the neural network is by computing the sharpness score\n\nsharpness_classification(hcat(y_onehot_test...)[3,:],vec(bernoulli_distributions))\n\n```\n\n\nThe neural network seems to be able to correctly classify the majority of samples not belonging to class 3 with a relative high confidence, but remains more uncertain when he encounter examples belonging to class 3.\n\n", + "markdown": "---\ntitle: Multi-class problem\n---\n\n\n\n## Libraries\n\n\n::: {.cell execution_count=1}\n``` {.julia .cell-code}\nusing Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux\ntheme(:lime)\n```\n:::\n\n\n## Data\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\nusing LaplaceRedux.Data\nseed = 1234\nx, y = Data.toy_data_multi(seed=seed)\nX = hcat(x...)\ny_onehot = Flux.onehotbatch(y, unique(y))\ny_onehot = Flux.unstack(y_onehot',1)\n```\n:::\n\n\nsplit in training and test datasets\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\n# Shuffle the data\nRandom.seed!(seed)\nn = length(y)\nindices = randperm(n)\n\n# Define the split ratio\nsplit_ratio = 0.8\nsplit_index = Int(floor(split_ratio * n))\n\n# Split the data into training and test sets\ntrain_indices = indices[1:split_index]\ntest_indices = indices[split_index+1:end]\n\nx_train = x[train_indices]\nx_test = x[test_indices]\ny_onehot_train = y_onehot[train_indices,:]\ny_onehot_test = y_onehot[test_indices,:]\n\ny_train = vec(y[train_indices,:])\ny_test = vec(y[test_indices,:])\n# bring into tabular format\nX_train = hcat(x_train...) \nX_test = hcat(x_test...) \n\ndata = zip(x_train,y_onehot_train)\n#data = zip(x,y_onehot)\n```\n:::\n\n\n## MLP\n\nWe set up a model\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nn_hidden = 3\nD = size(X,1)\nout_dim = length(unique(y))\nnn = Chain(\n Dense(D, n_hidden, σ),\n Dense(n_hidden, out_dim)\n) \nloss(x, y) = Flux.Losses.logitcrossentropy(nn(x), y)\n```\n:::\n\n\ntraining:\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\nusing Flux.Optimise: update!, Adam\nopt = Adam()\nepochs = 100\navg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(data)\n end\nend\n```\n:::\n\n\n## Laplace Approximation\n\nThe Laplace approximation can be implemented as follows:\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\nla = Laplace(nn; likelihood=:classification)\nfit!(la, data)\noptimize_prior!(la; verbosity=1, n_steps=100)\n```\n:::\n\n\nwith either the probit approximation:\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\n_labels = sort(unique(y))\nplt_list = []\nfor target in _labels\n plt = plot(la, X_test, y_test; target=target, clim=(0,1))\n push!(plt_list, plt)\nend\nplot(plt_list...)\n```\n\n::: {.cell-output .cell-output-display execution_count=8}\n![](multi_files/figure-commonmark/cell-8-output-1.svg){}\n:::\n:::\n\n\n or the plugin approximation:\n\n::: {.cell execution_count=8}\n``` {.julia .cell-code}\n_labels = sort(unique(y))\nplt_list = []\nfor target in _labels\n plt = plot(la, X_test, y_test; target=target, clim=(0,1), link_approx=:plugin)\n push!(plt_list, plt)\nend\nplot(plt_list...)\n```\n\n::: {.cell-output .cell-output-display execution_count=9}\n![](multi_files/figure-commonmark/cell-9-output-1.svg){}\n:::\n:::\n\n\n## Calibration Plots\n\nIn the case of multiclass classification tasks, we cannot plot the calibration plots directly since they can only be used in the binary classification case. However, we can use them to plot the calibration of the predictions for 1 class against all the others. To do so, we first have to collect the predicted categorical distributions\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\npredicted_distributions= predict(la, X_test,ret_distr=true)\n```\n\n::: {.cell-output .cell-output-display execution_count=10}\n```\n1×20 Matrix{Distributions.Categorical{Float64, Vector{Float64}}}:\n Distributions.Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[0.0569184, 0.196066, 0.0296796, 0.717336]) … Distributions.Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[0.0569634, 0.195727, 0.0296449, 0.717665])\n```\n:::\n:::\n\n\nthen we transform the categorical distributions into Bernoulli distributions by taking only the probability of the class of interest, for example the third one. \n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nusing Distributions\nbernoulli_distributions = [Bernoulli(p.p[3]) for p in vec(predicted_distributions)]\n```\n\n::: {.cell-output .cell-output-display execution_count=11}\n```\n20-element Vector{Bernoulli{Float64}}:\n Bernoulli{Float64}(p=0.029679590887034743)\n Bernoulli{Float64}(p=0.6682373773598078)\n Bernoulli{Float64}(p=0.20912995228011141)\n Bernoulli{Float64}(p=0.20913322913224044)\n Bernoulli{Float64}(p=0.02971989045895732)\n Bernoulli{Float64}(p=0.668431087463204)\n Bernoulli{Float64}(p=0.03311710703617972)\n Bernoulli{Float64}(p=0.20912981531862682)\n Bernoulli{Float64}(p=0.11273726979027407)\n Bernoulli{Float64}(p=0.2490744632745955)\n Bernoulli{Float64}(p=0.029886357844211404)\n Bernoulli{Float64}(p=0.02965323602487074)\n Bernoulli{Float64}(p=0.1126799374664026)\n Bernoulli{Float64}(p=0.11278538625980777)\n Bernoulli{Float64}(p=0.6683139127616431)\n Bernoulli{Float64}(p=0.029644435143197145)\n Bernoulli{Float64}(p=0.11324691083703237)\n Bernoulli{Float64}(p=0.6681422555922787)\n Bernoulli{Float64}(p=0.668424345470233)\n Bernoulli{Float64}(p=0.029644891255330787)\n```\n:::\n:::\n\n\nNow we can use ```Calibration_Plot``` to see the level of calibration of the neural network\n\n::: {.cell execution_count=11}\n``` {.julia .cell-code}\nplt = Calibration_Plot(la,hcat(y_onehot_test...)[3,:],bernoulli_distributions;n_bins = 10);\n```\n\n::: {.cell-output .cell-output-display}\n![](multi_files/figure-commonmark/cell-12-output-1.svg){}\n:::\n:::\n\n\nThe plot is peaked around 0.7.\n\nA possible reason is that class 3 is relatively easy for the model to identify from the other classes, although it remains a bit underconfident in its predictions. \nAnother reason for the peak may be the lack of cases where the predicted probability is lower (e.g., around 0.5), which could indicate that the network has not encountered ambiguous or difficult-to-classify examples for such class. This once again might be because either class 3 has distinct features that the model can easily learn, leading to fewer uncertain predictions, or is a consequence of the limited dataset. \n\n We can measure how sharp the neural network is by computing the sharpness score\n\nsharpness_classification(hcat(y_onehot_test...)[3,:],vec(bernoulli_distributions))\n\n```\n\n\nThe neural network seems to be able to correctly classify the majority of samples not belonging to class 3 with a relative high confidence, but remains more uncertain when he encounter examples belonging to class 3.\n\n", "supporting": [ "multi_files\\figure-commonmark" ], diff --git a/_freeze/docs/src/tutorials/regression/execute-results/md.json b/_freeze/docs/src/tutorials/regression/execute-results/md.json index dfc91d38..b5a1ed6f 100644 --- a/_freeze/docs/src/tutorials/regression/execute-results/md.json +++ b/_freeze/docs/src/tutorials/regression/execute-results/md.json @@ -2,7 +2,7 @@ "hash": "5a4b0f727b0fa06b8f883c934a123434", "result": { "engine": "jupyter", - "markdown": "```@meta\nCurrentModule = LaplaceRedux\n```\n## Libraries\n\nImport the libraries required to run this example\n\n::: {.cell execution_count=1}\n``` {.julia .cell-code}\nusing Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux\ntheme(:wong)\n```\n:::\n\n\n## Data\n\nWe first generate some synthetic data:\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\nusing LaplaceRedux.Data\nn = 3000 # number of observations\nσtrue = 0.30 # true observational noise\nx, y = Data.toy_data_regression(n;noise=σtrue,seed=1234)\nxs = [[x] for x in x]\nX = permutedims(x)\n```\n:::\n\n\nand split them in a training set and a test set\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\n# Shuffle the data\nRandom.seed!(1234) # Set a seed for reproducibility\nshuffle_indices = shuffle(1:n)\n\n# Define split ratios\ntrain_ratio = 0.8\ntest_ratio = 0.2\n\n# Calculate split indices\ntrain_end = Int(floor(train_ratio * n))\n\n# Split the data\ntrain_indices = shuffle_indices[1:train_end]\ntest_indices = shuffle_indices[train_end+1:end]\n\n# Create the splits\nx_train, y_train = x[train_indices], y[train_indices]\nx_test, y_test = x[test_indices], y[test_indices]\n\n# Optional: Convert to desired format\nxs_train = [[x] for x in x_train]\nxs_test = [[x] for x in x_test]\nX_train = permutedims(x_train)\nX_test = permutedims(x_test)\n```\n:::\n\n\n## MLP\n\nWe set up a model and loss with weight regularization:\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\ntrain_data = zip(xs_train,y_train)\nn_hidden = 50\nD = size(X,1)\nnn = Chain(\n Dense(D, n_hidden, tanh),\n Dense(n_hidden, 1)\n) \nloss(x, y) = Flux.Losses.mse(nn(x), y)\n```\n:::\n\n\nWe train the model:\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\nusing Flux.Optimise: update!, Adam\nopt = Adam(1e-3)\nepochs = 1000\navg_loss(train_data) = mean(map(d -> loss(d[1],d[2]), train_data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in train_data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(train_data)\n end\nend\n```\n:::\n\n\n## Laplace Approximation\n\nLaplace approximation can be implemented as follows:\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\nsubset_w = :all\nla = Laplace(nn; likelihood=:regression, subset_of_weights=subset_w)\nfit!(la, train_data)\nplot(la, X_train, y_train; zoom=-5, size=(400,400))\n```\n\n::: {.cell-output .cell-output-display execution_count=7}\n![](regression_files/figure-commonmark/cell-7-output-1.svg){}\n:::\n:::\n\n\nNext we optimize the prior precision $P_0$ and and observational noise $\\sigma$ using Empirical Bayes:\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\noptimize_prior!(la; verbose=true)\nplot(la, X_train, y_train; zoom=-5, size=(400,400))\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nloss(exp.(logP₀), exp.(logσ)) = 668.3714946472106\nLog likelihood: -618.5175117610522\nLog det ratio: 68.76532606873238\nScatter: 30.942639703584522\nloss(exp.(logP₀), exp.(logσ)) = 719.2536119935747\nLog likelihood: -673.0996963447847\nLog det ratio: 76.53255037599948\nScatter: 15.775280921580569\nloss(exp.(logP₀), exp.(logσ)) = 574.605864472924\nLog likelihood: -528.694286608232\n```\n:::\n\n::: {.cell-output .cell-output-stdout}\n```\n\nLog det ratio: 80.73114330857285\nScatter: 11.092012420811196\nloss(exp.(logP₀), exp.(logσ)) = 568.4433850825203\nLog likelihood: -522.4407550111031\nLog det ratio: 82.10089958560243\nScatter: 9.90436055723207\n```\n:::\n\n::: {.cell-output .cell-output-stdout}\n```\n\nloss(exp.(logP₀), exp.(logσ)) = 566.9485255672008\nLog likelihood: -520.9682443835385\nLog det ratio: 81.84516297272847\nScatter: 10.11539939459612\nloss(exp.(logP₀), exp.(logσ)) = 559.9852101992792\nLog likelihood: -514.0625630685765\nLog det ratio: 80.97813304453496\nScatter: 10.867161216870441\n```\n:::\n\n::: {.cell-output .cell-output-stdout}\n```\nloss(exp.(logP₀), exp.(logσ)) = 559.1404593114019\nLog likelihood: -513.2449017869876\nLog det ratio: 80.16026747795866\nScatter: 11.630847570869795\nloss(exp.(logP₀), exp.(logσ)) = 559.3201392562346\nLog likelihood: -513.4273312363501\nLog det ratio: 79.68892769076004\nScatter: 12.096688349008877\n```\n:::\n\n::: {.cell-output .cell-output-stdout}\n```\n\nloss(exp.(logP₀), exp.(logσ)) = 559.2111983983311\nLog likelihood: -513.3174948065804\nLog det ratio: 79.56631681347287\nScatter: 12.2210903700287\nloss(exp.(logP₀), exp.(logσ)) = 559.1107459310829\nLog likelihood: -513.2176579845662\nLog det ratio: 79.63946732368183\nScatter: 12.146708569351494\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=8}\n![](regression_files/figure-commonmark/cell-8-output-6.svg){}\n:::\n:::\n\n\n## Calibration Plot\nOnce the prior precision has been optimized it is possible to evaluate the quality of the predictive distribution \nobtained through a calibration plot and a test dataset (y_test, X_test). \n\nFirst, we apply the trained network on the test dataset (y_test, X_test) and collect the neural network's predicted distributions \n\n::: {.cell execution_count=8}\n``` {.julia .cell-code}\npredicted_distributions= predict(la, X_test,ret_distr=true)\n```\n\n::: {.cell-output .cell-output-display execution_count=9}\n```\n600×1 Matrix{Distributions.Normal{Float64}}:\n Distributions.Normal{Float64}(μ=-0.1137533187866211, σ=0.07161056521032018)\n Distributions.Normal{Float64}(μ=0.7063850164413452, σ=0.050697938829269665)\n Distributions.Normal{Float64}(μ=-0.2211049497127533, σ=0.06876939416479119)\n Distributions.Normal{Float64}(μ=0.720299243927002, σ=0.08665125572287981)\n Distributions.Normal{Float64}(μ=-0.8338974714279175, σ=0.06464012115237727)\n Distributions.Normal{Float64}(μ=0.9910320043563843, σ=0.07452060172164382)\n Distributions.Normal{Float64}(μ=0.1507074236869812, σ=0.07316299850461126)\n Distributions.Normal{Float64}(μ=0.20875799655914307, σ=0.05507748397231652)\n Distributions.Normal{Float64}(μ=0.973572850227356, σ=0.07899004963915071)\n Distributions.Normal{Float64}(μ=0.9497100114822388, σ=0.07750126389821968)\n Distributions.Normal{Float64}(μ=0.22462180256843567, σ=0.07103664786246695)\n Distributions.Normal{Float64}(μ=-0.7654240131378174, σ=0.05501397704409917)\n Distributions.Normal{Float64}(μ=1.0029183626174927, σ=0.07619466916431794)\n ⋮\n Distributions.Normal{Float64}(μ=0.7475956678390503, σ=0.049875919157527815)\n Distributions.Normal{Float64}(μ=0.019430622458457947, σ=0.07445076746045155)\n Distributions.Normal{Float64}(μ=-0.9451781511306763, σ=0.05929712369810892)\n Distributions.Normal{Float64}(μ=-0.9813591241836548, σ=0.05844012710417755)\n Distributions.Normal{Float64}(μ=-0.6470385789871216, σ=0.055754609087554294)\n Distributions.Normal{Float64}(μ=-0.34288135170936584, σ=0.05533523375842789)\n Distributions.Normal{Float64}(μ=0.9912381172180176, σ=0.07872473667398772)\n Distributions.Normal{Float64}(μ=-0.824547290802002, σ=0.05499258101374759)\n Distributions.Normal{Float64}(μ=-0.3306621015071869, σ=0.06745251908756716)\n Distributions.Normal{Float64}(μ=0.3742436170578003, σ=0.10588913330223387)\n Distributions.Normal{Float64}(μ=0.0875578224658966, σ=0.07436153828228255)\n Distributions.Normal{Float64}(μ=-0.34871187806129456, σ=0.06742745343084512)\n```\n:::\n:::\n\n\nthen we can plot the calibration plot of our neural model\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\nCalibration_Plot(la,y_test,vec(predicted_distributions);n_bins = 20)\n```\n\n::: {.cell-output .cell-output-display}\n![](regression_files/figure-commonmark/cell-10-output-1.svg){}\n:::\n:::\n\n\nand compute the sharpness of the predictive distribution\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nsharpness_regression(vec(predicted_distributions))\n```\n\n::: {.cell-output .cell-output-display execution_count=11}\n```\n0.005058067743863281\n```\n:::\n:::\n\n\n", + "markdown": "```@meta\nCurrentModule = LaplaceRedux\n```\n## Libraries\n\nImport the libraries required to run this example\n\n::: {.cell execution_count=1}\n``` {.julia .cell-code}\nusing Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux\ntheme(:wong)\n```\n:::\n\n\n## Data\n\nWe first generate some synthetic data:\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\nusing LaplaceRedux.Data\nn = 3000 # number of observations\nσtrue = 0.30 # true observational noise\nx, y = Data.toy_data_regression(n;noise=σtrue,seed=1234)\nxs = [[x] for x in x]\nX = permutedims(x)\n```\n:::\n\n\nand split them in a training set and a test set\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\n# Shuffle the data\nRandom.seed!(1234) # Set a seed for reproducibility\nshuffle_indices = shuffle(1:n)\n\n# Define split ratios\ntrain_ratio = 0.8\ntest_ratio = 0.2\n\n# Calculate split indices\ntrain_end = Int(floor(train_ratio * n))\n\n# Split the data\ntrain_indices = shuffle_indices[1:train_end]\ntest_indices = shuffle_indices[train_end+1:end]\n\n# Create the splits\nx_train, y_train = x[train_indices], y[train_indices]\nx_test, y_test = x[test_indices], y[test_indices]\n\n# Optional: Convert to desired format\nxs_train = [[x] for x in x_train]\nxs_test = [[x] for x in x_test]\nX_train = permutedims(x_train)\nX_test = permutedims(x_test)\n```\n:::\n\n\n## MLP\n\nWe set up a model and loss with weight regularization:\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\ntrain_data = zip(xs_train,y_train)\nn_hidden = 50\nD = size(X,1)\nnn = Chain(\n Dense(D, n_hidden, tanh),\n Dense(n_hidden, 1)\n) \nloss(x, y) = Flux.Losses.mse(nn(x), y)\n```\n:::\n\n\nWe train the model:\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\nusing Flux.Optimise: update!, Adam\nopt = Adam(1e-3)\nepochs = 1000\navg_loss(train_data) = mean(map(d -> loss(d[1],d[2]), train_data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in train_data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(train_data)\n end\nend\n```\n:::\n\n\n## Laplace Approximation\n\nLaplace approximation can be implemented as follows:\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\nsubset_w = :all\nla = Laplace(nn; likelihood=:regression, subset_of_weights=subset_w)\nfit!(la, train_data)\nplot(la, X_train, y_train; zoom=-5, size=(400,400))\n```\n\n::: {.cell-output .cell-output-display execution_count=7}\n![](regression_files/figure-commonmark/cell-7-output-1.svg){}\n:::\n:::\n\n\nNext we optimize the prior precision $P_0$ and and observational noise $\\sigma$ using Empirical Bayes:\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\noptimize_prior!(la; verbosity=1)\nplot(la, X_train, y_train; zoom=-5, size=(400,400))\n```\n\n::: {.cell-output .cell-output-stdout}\n```\nloss(exp.(logP₀), exp.(logσ)) = 668.3714946472106\nLog likelihood: -618.5175117610522\nLog det ratio: 68.76532606873238\nScatter: 30.942639703584522\nloss(exp.(logP₀), exp.(logσ)) = 719.2536119935747\nLog likelihood: -673.0996963447847\nLog det ratio: 76.53255037599948\nScatter: 15.775280921580569\nloss(exp.(logP₀), exp.(logσ)) = 574.605864472924\nLog likelihood: -528.694286608232\n```\n:::\n\n::: {.cell-output .cell-output-stdout}\n```\n\nLog det ratio: 80.73114330857285\nScatter: 11.092012420811196\nloss(exp.(logP₀), exp.(logσ)) = 568.4433850825203\nLog likelihood: -522.4407550111031\nLog det ratio: 82.10089958560243\nScatter: 9.90436055723207\n```\n:::\n\n::: {.cell-output .cell-output-stdout}\n```\n\nloss(exp.(logP₀), exp.(logσ)) = 566.9485255672008\nLog likelihood: -520.9682443835385\nLog det ratio: 81.84516297272847\nScatter: 10.11539939459612\nloss(exp.(logP₀), exp.(logσ)) = 559.9852101992792\nLog likelihood: -514.0625630685765\nLog det ratio: 80.97813304453496\nScatter: 10.867161216870441\n```\n:::\n\n::: {.cell-output .cell-output-stdout}\n```\nloss(exp.(logP₀), exp.(logσ)) = 559.1404593114019\nLog likelihood: -513.2449017869876\nLog det ratio: 80.16026747795866\nScatter: 11.630847570869795\nloss(exp.(logP₀), exp.(logσ)) = 559.3201392562346\nLog likelihood: -513.4273312363501\nLog det ratio: 79.68892769076004\nScatter: 12.096688349008877\n```\n:::\n\n::: {.cell-output .cell-output-stdout}\n```\n\nloss(exp.(logP₀), exp.(logσ)) = 559.2111983983311\nLog likelihood: -513.3174948065804\nLog det ratio: 79.56631681347287\nScatter: 12.2210903700287\nloss(exp.(logP₀), exp.(logσ)) = 559.1107459310829\nLog likelihood: -513.2176579845662\nLog det ratio: 79.63946732368183\nScatter: 12.146708569351494\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=8}\n![](regression_files/figure-commonmark/cell-8-output-6.svg){}\n:::\n:::\n\n\n## Calibration Plot\nOnce the prior precision has been optimized it is possible to evaluate the quality of the predictive distribution \nobtained through a calibration plot and a test dataset (y_test, X_test). \n\nFirst, we apply the trained network on the test dataset (y_test, X_test) and collect the neural network's predicted distributions \n\n::: {.cell execution_count=8}\n``` {.julia .cell-code}\npredicted_distributions= predict(la, X_test,ret_distr=true)\n```\n\n::: {.cell-output .cell-output-display execution_count=9}\n```\n600×1 Matrix{Distributions.Normal{Float64}}:\n Distributions.Normal{Float64}(μ=-0.1137533187866211, σ=0.07161056521032018)\n Distributions.Normal{Float64}(μ=0.7063850164413452, σ=0.050697938829269665)\n Distributions.Normal{Float64}(μ=-0.2211049497127533, σ=0.06876939416479119)\n Distributions.Normal{Float64}(μ=0.720299243927002, σ=0.08665125572287981)\n Distributions.Normal{Float64}(μ=-0.8338974714279175, σ=0.06464012115237727)\n Distributions.Normal{Float64}(μ=0.9910320043563843, σ=0.07452060172164382)\n Distributions.Normal{Float64}(μ=0.1507074236869812, σ=0.07316299850461126)\n Distributions.Normal{Float64}(μ=0.20875799655914307, σ=0.05507748397231652)\n Distributions.Normal{Float64}(μ=0.973572850227356, σ=0.07899004963915071)\n Distributions.Normal{Float64}(μ=0.9497100114822388, σ=0.07750126389821968)\n Distributions.Normal{Float64}(μ=0.22462180256843567, σ=0.07103664786246695)\n Distributions.Normal{Float64}(μ=-0.7654240131378174, σ=0.05501397704409917)\n Distributions.Normal{Float64}(μ=1.0029183626174927, σ=0.07619466916431794)\n ⋮\n Distributions.Normal{Float64}(μ=0.7475956678390503, σ=0.049875919157527815)\n Distributions.Normal{Float64}(μ=0.019430622458457947, σ=0.07445076746045155)\n Distributions.Normal{Float64}(μ=-0.9451781511306763, σ=0.05929712369810892)\n Distributions.Normal{Float64}(μ=-0.9813591241836548, σ=0.05844012710417755)\n Distributions.Normal{Float64}(μ=-0.6470385789871216, σ=0.055754609087554294)\n Distributions.Normal{Float64}(μ=-0.34288135170936584, σ=0.05533523375842789)\n Distributions.Normal{Float64}(μ=0.9912381172180176, σ=0.07872473667398772)\n Distributions.Normal{Float64}(μ=-0.824547290802002, σ=0.05499258101374759)\n Distributions.Normal{Float64}(μ=-0.3306621015071869, σ=0.06745251908756716)\n Distributions.Normal{Float64}(μ=0.3742436170578003, σ=0.10588913330223387)\n Distributions.Normal{Float64}(μ=0.0875578224658966, σ=0.07436153828228255)\n Distributions.Normal{Float64}(μ=-0.34871187806129456, σ=0.06742745343084512)\n```\n:::\n:::\n\n\nthen we can plot the calibration plot of our neural model\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\nCalibration_Plot(la,y_test,vec(predicted_distributions);n_bins = 20)\n```\n\n::: {.cell-output .cell-output-display}\n![](regression_files/figure-commonmark/cell-10-output-1.svg){}\n:::\n:::\n\n\nand compute the sharpness of the predictive distribution\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nsharpness_regression(vec(predicted_distributions))\n```\n\n::: {.cell-output .cell-output-display execution_count=11}\n```\n0.005058067743863281\n```\n:::\n:::\n\n\n", "supporting": [ "regression_files\\figure-commonmark" ], diff --git a/dev/notebooks/KFAC/sb/Multi.jl.ipynb b/dev/notebooks/KFAC/sb/Multi.jl.ipynb index ec8e8fde..f8835838 100644 --- a/dev/notebooks/KFAC/sb/Multi.jl.ipynb +++ b/dev/notebooks/KFAC/sb/Multi.jl.ipynb @@ -3256,7 +3256,7 @@ " @ Zygote ~/.julia/packages/Zygote/HTsWj/src/compiler/interface.jl:384", " [35] gradient(f::Function, args::Params{Zygote.Buffer{Any, Vector{Any}}})", " @ Zygote ~/.julia/packages/Zygote/HTsWj/src/compiler/interface.jl:96", - " [36] optimize_prior!(la::LaplaceRedux.KronLaplace; n_steps::Int64, lr::Float64, λinit::Nothing, σinit::Nothing, verbose::Bool, tune_σ::Bool)", + " [36] optimize_prior!(la::LaplaceRedux.KronLaplace; n_steps::Int64, lr::Float64, λinit::Nothing, σinit::Nothing, verbosity::Int, tune_σ::Bool)", " @ LaplaceRedux ~/Builds/navimakarov/LaplaceRedux.jl/src/laplace.jl:420", " [37] optimize_prior!(la::LaplaceRedux.KronLaplace)", " @ LaplaceRedux ~/Builds/navimakarov/LaplaceRedux.jl/src/laplace.jl:391", diff --git a/dev/notebooks/batching/Multi-class_classification.jl.ipynb b/dev/notebooks/batching/Multi-class_classification.jl.ipynb index e0c99aef..23e54645 100644 --- a/dev/notebooks/batching/Multi-class_classification.jl.ipynb +++ b/dev/notebooks/batching/Multi-class_classification.jl.ipynb @@ -372,7 +372,7 @@ "la = Laplace(nn; likelihood=:classification)\n", "fit!(la, data)\n", "# TODO find out what the prior is\n", - "# optimize_prior!(la; verbose=true, n_steps=1000)" + "# optimize_prior!(la; verbosity=1, n_steps=1000)" ] }, { diff --git a/dev/notebooks/batching/Trials-01-Zygote.ipynb b/dev/notebooks/batching/Trials-01-Zygote.ipynb index cad86931..683cdf53 100644 --- a/dev/notebooks/batching/Trials-01-Zygote.ipynb +++ b/dev/notebooks/batching/Trials-01-Zygote.ipynb @@ -1446,7 +1446,7 @@ "if outdim == 1\n", " la = Laplace(nn; likelihood=likelihood, λ=λ, subset_of_weights=:last_layer)\n", " fit!(la, data)\n", - " # optimize_prior!(la; verbose=true)\n", + " # optimize_prior!(la; verbosity=1)\n", " plot(la, X, y, title=\"batchsize=N/A\") # standard\n", " # savefig(@sprintf(\"fig-01-%02d.png\", 0))\n", " # plot(la, X, y; xlims=(-5, 5), ylims=(-5, 5)) # lims\n", @@ -1869,7 +1869,7 @@ "if outdim == 1\n", " la = Laplace(nn; likelihood=likelihood, λ=λ, subset_of_weights=:last_layer)\n", " fit!(la, data)\n", - " # optimize_prior!(la; verbose=false)\n", + " # optimize_prior!(la; verbosity=0)\n", " plot(la, X, y, title=\"batchsize=$b\") # standard\n", " # savefig(@sprintf(\"fig-%02d.png\", batchsize))\n", " # plot(la, X, y; xlims=(-5, 5), ylims=(-5, 5)) # lims\n", @@ -2101,7 +2101,7 @@ " if outdim == 1\n", " la = Laplace(nn; likelihood=likelihood, λ=λ, subset_of_weights=:last_layer)\n", " fit!(la, data)\n", - " # optimize_prior!(la; verbose=true, show_every=10_000)\n", + " # optimize_prior!(la; verbosity=1, show_every=10_000)\n", " plot(la, X, y, title=\"batchsize=$batchsize\") # standard\n", " savefig(@sprintf(\"fig-01-%02d.png\", batchsize))\n", " # plot(la, X, y; xlims=(-5, 5), ylims=(-5, 5)) # lims\n", diff --git a/dev/notebooks/batching/Trials-03-Jacobians.ipynb b/dev/notebooks/batching/Trials-03-Jacobians.ipynb index 9411d345..d807aa91 100644 --- a/dev/notebooks/batching/Trials-03-Jacobians.ipynb +++ b/dev/notebooks/batching/Trials-03-Jacobians.ipynb @@ -1288,7 +1288,7 @@ "source": [ "# la = Laplace(nn; likelihood=:classification)\n", "# fit!(la, data)\n", - "# optimize_prior!(la; verbose=true, n_steps=1000)" + "# optimize_prior!(la; verbosity=1, n_steps=1000)" ] }, { diff --git a/dev/notebooks/batching/regression.ipynb b/dev/notebooks/batching/regression.ipynb index ed7251da..d09c3cd6 100644 --- a/dev/notebooks/batching/regression.ipynb +++ b/dev/notebooks/batching/regression.ipynb @@ -1945,7 +1945,7 @@ "source": [ "#| output: true\n", "\n", - "optimize_prior!(la; verbose=true)\n", + "optimize_prior!(la; verbosity=1)\n", "plot(la, X, y; zoom=-5, size=(400,400))" ] }, diff --git a/dev/notebooks/ggn/GGN-Julia.ipynb b/dev/notebooks/ggn/GGN-Julia.ipynb index 027bd6df..e9d282ba 100644 --- a/dev/notebooks/ggn/GGN-Julia.ipynb +++ b/dev/notebooks/ggn/GGN-Julia.ipynb @@ -245,7 +245,7 @@ "# Laplace(nn; likelihood=:classification, backend=:EmpiricalFisher) - to use Empirical Fisher as a backend \n", "la = Laplace(nn; likelihood=:classification) \n", "fit!(la, data)\n", - "optimize_prior!(la; verbose=true, n_steps=100)" + "optimize_prior!(la; verbosity=1, n_steps=100)" ] }, { diff --git a/dev/notebooks/multi-class/Multi-Class-Julia-FGD.ipynb b/dev/notebooks/multi-class/Multi-Class-Julia-FGD.ipynb index c1cc520c..d7ddceca 100644 --- a/dev/notebooks/multi-class/Multi-Class-Julia-FGD.ipynb +++ b/dev/notebooks/multi-class/Multi-Class-Julia-FGD.ipynb @@ -257,7 +257,7 @@ "using LaplaceRedux\n", "la = Laplace(nn; likelihood=:classification)\n", "fit!(la, data)\n", - "optimize_prior!(la; verbose=true, n_steps=1000)" + "optimize_prior!(la; verbosity=1, n_steps=1000)" ] }, { diff --git a/dev/notebooks/multi-class/Multi-Class-Julia-SGD.ipynb b/dev/notebooks/multi-class/Multi-Class-Julia-SGD.ipynb index 6127ec78..6b06343b 100644 --- a/dev/notebooks/multi-class/Multi-Class-Julia-SGD.ipynb +++ b/dev/notebooks/multi-class/Multi-Class-Julia-SGD.ipynb @@ -294,7 +294,7 @@ "using LaplaceRedux\n", "la = Laplace(nn; likelihood=:classification)\n", "fit!(la, data)\n", - "optimize_prior!(la; verbose=true, n_steps=100)" + "optimize_prior!(la; verbosity=1, n_steps=100)" ] }, { diff --git a/dev/notebooks/network_subsets/subnetworks_laplace.ipynb b/dev/notebooks/network_subsets/subnetworks_laplace.ipynb index cc7948c6..bf318dd9 100644 --- a/dev/notebooks/network_subsets/subnetworks_laplace.ipynb +++ b/dev/notebooks/network_subsets/subnetworks_laplace.ipynb @@ -2135,7 +2135,7 @@ "fit!(la, data)\n", "\n", "la_untuned = deepcopy(la) # saving for plotting\n", - "optimize_prior!(la; verbose=true, n_steps=500)\n", + "optimize_prior!(la; verbosity=1, n_steps=500)\n", "\n", "zoom=0\n", "println(\"...\")\n", @@ -4041,7 +4041,7 @@ "fit!(la, data)\n", "\n", "la_untuned = deepcopy(la) # saving for plotting\n", - "optimize_prior!(la; verbose=true, n_steps=500)\n", + "optimize_prior!(la; verbosity=1, n_steps=500)\n", "\n", "zoom=0\n", "println(\"...\")\n", @@ -5955,7 +5955,7 @@ "fit!(la, data)\n", "\n", "la_untuned = deepcopy(la) # saving for plotting\n", - "optimize_prior!(la; verbose=true, n_steps=500)\n", + "optimize_prior!(la; verbosity=1, n_steps=500)\n", "\n", "zoom=0\n", "println(\"...\")\n", @@ -7877,7 +7877,7 @@ "fit!(la, data)\n", "\n", "la_untuned = deepcopy(la) # saving for plotting\n", - "optimize_prior!(la; verbose=true, n_steps=500)\n", + "optimize_prior!(la; verbosity=1, n_steps=500)\n", "\n", "zoom=0\n", "println(\"...\")\n", diff --git a/docs/src/tutorials/logit.md b/docs/src/tutorials/logit.md index e7bd7034..341317c8 100644 --- a/docs/src/tutorials/logit.md +++ b/docs/src/tutorials/logit.md @@ -97,7 +97,7 @@ Laplace approximation for the posterior predictive can be implemented as follows la = Laplace(nn; likelihood=:classification, λ=λ, subset_of_weights=:last_layer) fit!(la, data) la_untuned = deepcopy(la) # saving for plotting -optimize_prior!(la; verbose=true, n_steps=500) +optimize_prior!(la; verbosity=1, n_steps=500) ``` The plot below shows the resulting posterior predictive surface for the plugin estimator (left) and the Laplace approximation (right). diff --git a/docs/src/tutorials/logit.qmd b/docs/src/tutorials/logit.qmd index bdafa613..6e111684 100644 --- a/docs/src/tutorials/logit.qmd +++ b/docs/src/tutorials/logit.qmd @@ -96,7 +96,7 @@ Laplace approximation for the posterior predictive can be implemented as follows la = Laplace(nn; likelihood=:classification, λ=λ, subset_of_weights=:last_layer) fit!(la, data) la_untuned = deepcopy(la) # saving for plotting -optimize_prior!(la; verbose=true, n_steps=500) +optimize_prior!(la; verbosity=1, n_steps=500) ``` The plot below shows the resulting posterior predictive surface for the plugin estimator (left) and the Laplace approximation (right). diff --git a/docs/src/tutorials/mlp.md b/docs/src/tutorials/mlp.md index 69576328..e0a3e896 100644 --- a/docs/src/tutorials/mlp.md +++ b/docs/src/tutorials/mlp.md @@ -93,7 +93,7 @@ Laplace approximation can be implemented as follows: la = Laplace(nn; likelihood=:classification, subset_of_weights=:all) fit!(la, data) la_untuned = deepcopy(la) # saving for plotting -optimize_prior!(la; verbose=true, n_steps=500) +optimize_prior!(la; verbosity=1, n_steps=500) ``` The plot below shows the resulting posterior predictive surface for the plugin estimator (left) and the Laplace approximation (right). diff --git a/docs/src/tutorials/mlp.qmd b/docs/src/tutorials/mlp.qmd index 59f9ad62..e1029d63 100644 --- a/docs/src/tutorials/mlp.qmd +++ b/docs/src/tutorials/mlp.qmd @@ -93,7 +93,7 @@ Laplace approximation can be implemented as follows: la = Laplace(nn; likelihood=:classification, subset_of_weights=:all) fit!(la, data) la_untuned = deepcopy(la) # saving for plotting -optimize_prior!(la; verbose=true, n_steps=500) +optimize_prior!(la; verbosity=1, n_steps=500) ``` The plot below shows the resulting posterior predictive surface for the plugin estimator (left) and the Laplace approximation (right). diff --git a/docs/src/tutorials/multi.md b/docs/src/tutorials/multi.md index fbb89aee..885112bf 100644 --- a/docs/src/tutorials/multi.md +++ b/docs/src/tutorials/multi.md @@ -97,7 +97,7 @@ The Laplace approximation can be implemented as follows: ``` julia la = Laplace(nn; likelihood=:classification) fit!(la, data) -optimize_prior!(la; verbose=true, n_steps=100) +optimize_prior!(la; verbosity=1, n_steps=100) ``` with either the probit approximation: diff --git a/docs/src/tutorials/multi.qmd b/docs/src/tutorials/multi.qmd index 6195a0ae..5ea5f019 100644 --- a/docs/src/tutorials/multi.qmd +++ b/docs/src/tutorials/multi.qmd @@ -97,7 +97,7 @@ The Laplace approximation can be implemented as follows: ```{julia} la = Laplace(nn; likelihood=:classification) fit!(la, data) -optimize_prior!(la; verbose=true, n_steps=100) +optimize_prior!(la; verbosity=1, n_steps=100) ``` with either the probit approximation: diff --git a/docs/src/tutorials/regression.md b/docs/src/tutorials/regression.md index 223f9b5e..b58d34d5 100644 --- a/docs/src/tutorials/regression.md +++ b/docs/src/tutorials/regression.md @@ -111,7 +111,7 @@ plot(la, X_train, y_train; zoom=-5, size=(400,400)) Next we optimize the prior precision $P_0$ and and observational noise $\sigma$ using Empirical Bayes: ``` julia -optimize_prior!(la; verbose=true) +optimize_prior!(la; verbosity=1) plot(la, X_train, y_train; zoom=-5, size=(400,400)) ``` diff --git a/docs/src/tutorials/regression.qmd b/docs/src/tutorials/regression.qmd index 6cce7c06..2faee5a3 100644 --- a/docs/src/tutorials/regression.qmd +++ b/docs/src/tutorials/regression.qmd @@ -110,7 +110,7 @@ Next we optimize the prior precision $P_0$ and and observational noise $\sigma$ ```{julia} #| output: true -optimize_prior!(la; verbose=true) +optimize_prior!(la; verbosity=1) plot(la, X_train, y_train; zoom=-5, size=(400,400)) ``` diff --git a/src/LaplaceRedux.jl b/src/LaplaceRedux.jl index 9a36d18e..ed931ca7 100644 --- a/src/LaplaceRedux.jl +++ b/src/LaplaceRedux.jl @@ -19,9 +19,6 @@ export fit!, predict export optimize_prior!, glm_predictive_distribution, posterior_covariance, posterior_precision -include("mlj_flux.jl") -export LaplaceClassification -export LaplaceRegression include("calibration_functions.jl") export empirical_frequency_binary_classification, diff --git a/src/baselaplace/optimize_prior.jl b/src/baselaplace/optimize_prior.jl index f5a16407..08f3c496 100644 --- a/src/baselaplace/optimize_prior.jl +++ b/src/baselaplace/optimize_prior.jl @@ -14,7 +14,7 @@ function optimize_prior!( lr::Real=1e-1, λinit::Union{Nothing,Real}=nothing, σinit::Union{Nothing,Real}=nothing, - verbose::Bool=false, + verbosity::Int=0, tune_σ::Bool=la.likelihood == :regression, ) @@ -42,7 +42,7 @@ function optimize_prior!( end Flux.Optimise.update!(opt, ps, gs) i += 1 - if verbose + if verbosity>0 if i % show_every == 0 @info "Iteration $(i): P₀=$(exp(logP₀[1])), σ=$(exp(logσ[1]))" @show loss(exp.(logP₀), exp.(logσ)) diff --git a/src/mlj_flux.jl b/src/mlj_flux.jl deleted file mode 100644 index 0b1fa68f..00000000 --- a/src/mlj_flux.jl +++ /dev/null @@ -1,494 +0,0 @@ -using Flux -using MLJFlux -using ProgressMeter: Progress, next!, BarGlyphs -using Random -using Tables -using LinearAlgebra -using LaplaceRedux -using ComputationalResources -using MLJBase: MLJBase -import MLJModelInterface as MMI -using Optimisers: Optimisers - -""" - MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic - -A mutable struct representing a Laplace regression model that extends the `MLJFlux.MLJFluxProbabilistic` abstract type. -It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. - -The model is defined by the following default parameters for all `MLJFlux` models: - -- `builder`: a Flux model that constructs the neural network. -- `optimiser`: a Flux optimiser. -- `loss`: a loss function that takes the predicted output and the true output as arguments. -- `epochs`: the number of epochs. -- `batch_size`: the size of a batch. -- `lambda`: the regularization strength. -- `alpha`: the regularization mix (0 for all l2, 1 for all l1). -- `rng`: a random number generator. -- `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining. -- `acceleration`: the computational resource to use. - -The model also has the following parameters, which are specific to the Laplace approximation: - -- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. -- `subnetwork_indices`: the indices of the subnetworks. -- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. -- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`. -- `σ`: the standard deviation of the prior distribution. -- `μ₀`: the mean of the prior distribution. -- `P₀`: the covariance matrix of the prior distribution. -- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities. -- `fit_prior_nsteps`: the number of steps used to fit the priors. -""" -MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic - builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) - optimiser = Optimisers.Adam() - loss = Flux.Losses.mse - epochs::Int = 10::(_ > 0) - batch_size::Int = 1::(_ > 0) - lambda::Float64 = 1.0 - alpha::Float64 = 0.0 - rng::Union{AbstractRNG,Int64} = Random.GLOBAL_RNG - optimiser_changes_trigger_retraining::Bool = false::(_ in (true, false)) - acceleration = CPU1()::(_ in (CPU1(), CUDALibs())) - subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) - subnetwork_indices = nothing - hessian_structure::Union{HessianStructure,Symbol,String} = - :full::(_ in (:full, :diagonal)) - backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)) - σ::Float64 = 1.0 - μ₀::Float64 = 0.0 - P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing - ret_distr::Bool = false::(_ in (true, false)) - fit_prior_nsteps::Int = 100::(_ > 0) -end - -""" - MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic - -A mutable struct representing a Laplace Classification model that extends the MLJFluxProbabilistic abstract type. -It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. - -The model is defined by the following default parameters for all `MLJFlux` models: -- `builder`: a Flux model that constructs the neural network. -- `finaliser`: a Flux model that processes the output of the neural network. -- `optimiser`: a Flux optimiser. -- `loss`: a loss function that takes the predicted output and the true output as arguments. -- `epochs`: the number of epochs. -- `batch_size`: the size of a batch. -- `lambda`: the regularization strength. -- `alpha`: the regularization mix (0 for all l2, 1 for all l1). -- `rng`: a random number generator. -- `optimiser_changes_trigger_retraining`: a boolean indicating whether changes in the optimiser trigger retraining. -- `acceleration`: the computational resource to use. - -The model also has the following parameters, which are specific to the Laplace approximation: - -- `subset_of_weights`: the subset of weights to use, either `:all`, `:last_layer`, or `:subnetwork`. -- `subnetwork_indices`: the indices of the subnetworks. -- `hessian_structure`: the structure of the Hessian matrix, either `:full` or `:diagonal`. -- `backend`: the backend to use, either `:GGN` or `:EmpiricalFisher`. -- `σ`: the standard deviation of the prior distribution. -- `μ₀`: the mean of the prior distribution. -- `P₀`: the covariance matrix of the prior distribution. -- `link_approx`: the link approximation to use, either `:probit` or `:plugin`. -- `predict_proba`: a boolean that select whether to predict probabilities or not. -- `ret_distr`: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities. -- `fit_prior_nsteps`: the number of steps used to fit the priors. -""" -MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic - builder = MLJFlux.MLP(; hidden=(32, 32, 32), σ=Flux.swish) - finaliser = Flux.softmax - optimiser = Optimisers.Adam() - loss = Flux.crossentropy - epochs::Int = 10::(_ > 0) - batch_size::Int = 1::(_ > 0) - lambda::Float64 = 1.0 - alpha::Float64 = 0.0 - rng::Union{AbstractRNG,Int64} = Random.GLOBAL_RNG - optimiser_changes_trigger_retraining::Bool = false::(_ in (true, false)) - acceleration = CPU1()::(_ in (CPU1(), CUDALibs())) - subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)) - subnetwork_indices::Vector{Vector{Int}} = Vector{Vector{Int}}([]) - hessian_structure::Union{HessianStructure,Symbol,String} = - :full::(_ in (:full, :diagonal)) - backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)) - σ::Float64 = 1.0 - μ₀::Float64 = 0.0 - P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing - link_approx::Symbol = :probit::(_ in (:probit, :plugin)) - predict_proba::Bool = true::(_ in (true, false)) - ret_distr::Bool = false::(_ in (true, false)) - fit_prior_nsteps::Int = 100::(_ > 0) -end - -const MLJ_Laplace = Union{LaplaceClassification,LaplaceRegression} - -""" - MLJFlux.shape(model::LaplaceRegression, X, y) - -Compute the the number of features of the X input dataset and the number of variables to predict from the y output dataset. - -# Arguments -- `model::LaplaceRegression`: The LaplaceRegression model to fit. -- `X`: The input data for training. -- `y`: The target labels for training one-hot encoded. - -# Returns -- (input size, output size) -""" -function MLJFlux.shape(model::LaplaceRegression, X, y) - X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - n_input = size(X, 2) - dims = size(y) - if length(dims) == 1 - n_output = 1 - else - n_output = dims[1] - end - return (n_input, n_output) -end - -""" - MLJFlux.build(model::LaplaceRegression, rng, shape) - -Builds an MLJFlux model for Laplace regression compatible with the dimensions of the input and output layers specified by `shape`. - -# Arguments -- `model::LaplaceRegression`: The Laplace regression model. -- `rng`: A random number generator to ensure reproducibility. -- `shape`: A tuple or array specifying the dimensions of the input and output layers. - -# Returns -- The constructed MLJFlux model, compatible with the specified input and output dimensions. -""" -function MLJFlux.build(model::LaplaceRegression, rng, shape) - chain = MLJFlux.build(model.builder, rng, shape...) - return chain -end - -""" - MLJFlux.fitresult(model::LaplaceRegression, chain, y) - -Computes the fit result for a Laplace Regression model, returning the model chain and the number of output variables in the target data. - -# Arguments -- `model::LaplaceRegression`: The Laplace Regression model to be evaluated. -- `chain`: The trained model chain. -- `y`: The target data, typically a vector of class labels. - -# Returns - A tuple containing: - - The trained Flux chain. - - a deepcopy of the laplace model. -""" -function MLJFlux.fitresult(model::LaplaceRegression, chain, y) - return (chain, deepcopy(model)) -end - -""" - MLJFlux.train(model::LaplaceRegression, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y) - -Fit the LaplaceRegression model using Flux.jl. - -# Arguments -- `model::LaplaceRegression`: The LaplaceRegression model. -- `regularized_optimiser`: the regularized optimiser to apply to the loss function. -- `optimiser_state`: thestate of the optimiser. -- `epochs`: The number of epochs for training. -- `verbosity`: The verbosity level for training. -- `X`: The input data for training. -- `y`: The target labels for training. - -# Returns (la, optimiser_state, history ) -where -- `la`: the fitted Laplace model. -- `optimiser_state`: the state of the optimiser. -- `history`: the training loss history. -""" -function MLJFlux.train( - model::LaplaceRegression, - chain, - regularized_optimiser, - optimiser_state, - epochs, - verbosity, - X, - y, -) - X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - - # Initialize history: - history = [] - verbose_laplace = false - # intitialize and start progress meter: - meter = Progress( - epochs + 1; - dt=1.0, - desc="Optimising neural net:", - barglyphs=BarGlyphs("[=> ]"), - barlen=25, - color=:yellow, - ) - verbosity != 1 || next!(meter) - - # initiate history: - loss = model.loss - losses = (loss(chain(X[i]), y[i]) for i in 1:length(y)) - history = [mean(losses)] - - for i in 1:epochs - chain, optimiser_state, current_loss = MLJFlux.train_epoch( - model, chain, regularized_optimiser, optimiser_state, X, y - ) - verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" - verbosity != 1 || next!(meter) - push!(history, current_loss) - end - - if !isa(chain, AbstractLaplace) - la = LaplaceRedux.Laplace( - chain; - likelihood=:regression, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, - σ=model.σ, - μ₀=model.μ₀, - P₀=model.P₀, - ) - else - la = chain - end - - # fit the Laplace model: - LaplaceRedux.fit!(la, zip(X, y)) - optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) - - return la, optimiser_state, history -end - -""" - predict(model::LaplaceRegression, Xnew) - -Predict the output for new input data using a Laplace regression model. - -# Arguments -- `model::LaplaceRegression`: The trained Laplace regression model. -- the fitresult output produced by MLJFlux.fit! -- `Xnew`: The new input data. - -# Returns -- The predicted output for the new input data. - -""" -function MLJFlux.predict(model::LaplaceRegression, fitresult, Xnew) - Xnew = MLJBase.matrix(Xnew) |> permutedims - la = fitresult[1] - yhat = LaplaceRedux.predict(la, Xnew; ret_distr=model.ret_distr) - return yhat -end - -""" - MLJFlux.shape(model::LaplaceClassification, X, y) - -Compute the the number of features of the dataset X and the number of unique classes in y. - -# Arguments -- `model::LaplaceClassification`: The LaplaceClassification model to fit. -- `X`: The input data for training. -- `y`: The target labels for training one-hot encoded. - -# Returns -- (input size, output size) -""" - -function MLJFlux.shape(model::LaplaceClassification, X, y) - X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - n_input = size(X, 2) - levels = unique(y) - n_output = length(levels) - return (n_input, n_output) -end - -""" - MLJFlux.build(model::LaplaceClassification, rng, shape) - -Builds an MLJFlux model for Laplace classification compatible with the dimensions of the input and output layers specified by `shape`. - -# Arguments -- `model::LaplaceClassification`: The Laplace classification model. -- `rng`: A random number generator to ensure reproducibility. -- `shape`: A tuple or array specifying the dimensions of the input and output layers. - -# Returns -- The constructed MLJFlux model, compatible with the specified input and output dimensions. -""" -function MLJFlux.build(model::LaplaceClassification, rng, shape) - chain = Flux.Chain(MLJFlux.build(model.builder, rng, shape...), model.finaliser) - - return chain -end - -""" - MLJFlux.fitresult(model::LaplaceClassification, chain, y) - -Computes the fit result for a Laplace classification model, returning the model chain and the number of unique classes in the target data. - -# Arguments -- `model::LaplaceClassification`: The Laplace classification model to be evaluated. -- `chain`: The trained model chain. -- `y`: The target data, typically a vector of class labels. - -# Returns -# Returns - A tuple containing: - - The trained Flux chain. - - a deepcopy of the laplace model. -""" -function MLJFlux.fitresult(model::LaplaceClassification, chain, y) - return (chain, deepcopy(model)) -end - -""" - MLJFlux.train(model::LaplaceClassification, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y) - -Fit the LaplaceRegression model using Flux.jl. - -# Arguments -- `model::LaplaceClassification`: The LaplaceClassification model. -- `regularized_optimiser`: the regularized optimiser to apply to the loss function. -- `optimiser_state`: thestate of the optimiser. -- `epochs`: The number of epochs for training. -- `verbosity`: The verbosity level for training. -- `X`: The input data for training. -- `y`: The target labels for training. - -# Returns (fitresult, cache, report ) -where -- `la`: the fitted Laplace model. -- `optimiser_state`: the state of the optimiser. -- `history`: the training loss history. -""" -function MLJFlux.train( - model::LaplaceClassification, - chain, - regularized_optimiser, - optimiser_state, - epochs, - verbosity, - X, - y, -) - X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X - - # Initialize history: - history = [] - verbose_laplace = false - # intitialize and start progress meter: - meter = Progress( - epochs + 1; - dt=1.0, - desc="Optimising neural net:", - barglyphs=BarGlyphs("[=> ]"), - barlen=25, - color=:yellow, - ) - verbosity != 1 || next!(meter) - - # initiate history: - loss = model.loss - n_batches = length(y) - losses = (loss(chain(X[i]), y[i]) for i in 1:n_batches) - history = [mean(losses)] - - for i in 1:epochs - chain, optimiser_state, current_loss = MLJFlux.train_epoch( - model, chain, regularized_optimiser, optimiser_state, X, y - ) - verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" - verbosity != 1 || next!(meter) - push!(history, current_loss) - end - - if !isa(chain, AbstractLaplace) - la = LaplaceRedux.Laplace( - chain; - likelihood=:regression, - subset_of_weights=model.subset_of_weights, - subnetwork_indices=model.subnetwork_indices, - hessian_structure=model.hessian_structure, - backend=model.backend, - σ=model.σ, - μ₀=model.μ₀, - P₀=model.P₀, - ) - else - la = chain - end - - # fit the Laplace model: - LaplaceRedux.fit!(la, zip(X, y)) - optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps) - - return la, optimiser_state, history -end - -""" - predict(model::LaplaceClassification, Xnew) - -Predicts the class labels for new data using the LaplaceClassification model. - -# Arguments -- `model::LaplaceClassification`: The trained LaplaceClassification model. -- fitresult: the fitresult output produced by MLJFlux.fit! -- `Xnew`: The new data to make predictions on. - -# Returns -An array of predicted class labels. - -""" -function MLJFlux.predict(model::LaplaceClassification, fitresult, Xnew) - la = fitresult[1] - Xnew = MLJBase.matrix(Xnew) |> permutedims - predictions = LaplaceRedux.predict( - la, - Xnew; - link_approx=model.link_approx, - predict_proba=model.predict_proba, - ret_distr=model.ret_distr, - ) - - return predictions -end - -# metadata for each model, -MLJBase.metadata_model( - LaplaceClassification; - input=Union{ - AbstractMatrix{MLJBase.Finite}, - MLJBase.Table(MLJBase.Finite), - AbstractMatrix{MLJBase.Continuous}, - MLJBase.Table(MLJBase.Continuous), - MLJBase.Table{AbstractVector{MLJBase.Continuous}}, - MLJBase.Table{AbstractVector{MLJBase.Finite}}, - }, - target=Union{AbstractArray{MLJBase.Finite},AbstractArray{MLJBase.Continuous}}, - path="MLJFlux.LaplaceClassification", -) -# metadata for each model, -MLJBase.metadata_model( - LaplaceRegression; - input=Union{ - AbstractMatrix{MLJBase.Continuous}, - MLJBase.Table(MLJBase.Continuous), - AbstractMatrix{MLJBase.Finite}, - MLJBase.Table(MLJBase.Finite), - MLJBase.Table{AbstractVector{MLJBase.Continuous}}, - MLJBase.Table{AbstractVector{MLJBase.Finite}}, - }, - target=Union{AbstractArray{MLJBase.Finite},AbstractArray{MLJBase.Continuous}}, - path="MLJFlux.LaplaceRegression", -) diff --git a/test/Manifest.toml b/test/Manifest.toml index 7f9e18a9..31dcb255 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.3" +julia_version = "1.10.5" manifest_format = "2.0" -project_hash = "30dc96d6146892242111894ebf221bf701ee0fdd" +project_hash = "2fde859c875aff2c1b66bd10b3f4f3d64f67067a" [[deps.AbstractFFTs]] deps = ["LinearAlgebra"] @@ -21,31 +21,35 @@ uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" version = "0.4.5" [[deps.Accessors]] -deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] -git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" +deps = ["CompositionsBase", "ConstructionBase", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown"] +git-tree-sha1 = "b392ede862e506d451fc1616e79aa6f4c673dab8" uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.37" +version = "0.1.38" [deps.Accessors.extensions] AccessorsAxisKeysExt = "AxisKeys" + AccessorsDatesExt = "Dates" AccessorsIntervalSetsExt = "IntervalSets" AccessorsStaticArraysExt = "StaticArrays" AccessorsStructArraysExt = "StructArrays" + AccessorsTestExt = "Test" AccessorsUnitfulExt = "Unitful" [deps.Accessors.weakdeps] AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" +git-tree-sha1 = "d80af0733c99ea80575f612813fa6aa71022d33a" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.4" +version = "4.1.0" weakdeps = ["StaticArrays"] [deps.Adapt.extensions] @@ -59,9 +63,9 @@ version = "1.1.3" [[deps.Aqua]] deps = ["Compat", "Pkg", "Test"] -git-tree-sha1 = "12e575f31a6f233ba2485ed86b9325b85df37c61" +git-tree-sha1 = "49b1d7a9870c87ba13dc63f8ccfcf578cb266f95" uuid = "4c88cf16-eb10-579e-8560-4a9242c79595" -version = "0.8.7" +version = "0.8.9" [[deps.ArgCheck]] git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" @@ -105,16 +109,11 @@ git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" version = "0.5.0" -[[deps.BSON]] -git-tree-sha1 = "4c3e506685c527ac6a54ccc0c8c76fd6f91b42fb" -uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" -version = "0.3.9" - [[deps.BangBang]] -deps = ["Accessors", "Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] -git-tree-sha1 = "08e5fc6620a8d83534bf6149795054f1b1e8370a" +deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] +git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae" uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.4.2" +version = "0.4.3" [deps.BangBang.extensions] BangBangChainRulesCoreExt = "ChainRulesCore" @@ -146,15 +145,15 @@ uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" version = "0.1.9" [[deps.BufferedStreams]] -git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" +git-tree-sha1 = "6863c5b7fc997eadcabdbaf6c5f201dc30032643" uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" -version = "1.2.1" +version = "1.2.2" [[deps.Bzip2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9e2a6b69137e6969bab0152632dcb3bc108c8bdd" +git-tree-sha1 = "8873e196c2eb87962a2048b3b8e08946535864a1" uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" -version = "1.0.8+1" +version = "1.0.8+2" [[deps.CEnum]] git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" @@ -163,21 +162,15 @@ version = "0.5.0" [[deps.CSV]] deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] -git-tree-sha1 = "6c834533dc1fabd820c1db03c839bf97e45a3fab" +git-tree-sha1 = "deddd8725e5e1cc49ee205a1964256043720a6c3" uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -version = "0.10.14" +version = "0.10.15" [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a2f1c8c668c8e3cb4cca4e57a8efdb09067bb3fd" +git-tree-sha1 = "009060c9a6168704143100f36ab08f06c2af4642" uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" -version = "1.18.0+2" - -[[deps.Calculus]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" -uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" -version = "0.5.1" +version = "1.18.2+1" [[deps.CategoricalArrays]] deps = ["DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Unicode"] @@ -206,15 +199,15 @@ version = "0.1.15" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" +git-tree-sha1 = "be227d253d132a6d57f9ccf5f67c0fb6488afd87" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.69.0" +version = "1.71.0" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" +git-tree-sha1 = "3e4b134270b372f2ed4d4d0e936aabaefc1802bc" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.24.0" +version = "1.25.0" weakdeps = ["SparseArrays"] [deps.ChainRulesCore.extensions] @@ -234,9 +227,9 @@ version = "0.10.4+0" [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" +git-tree-sha1 = "bce6804e5e6044c6daab27bb533d1295e4a2e759" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.5" +version = "0.7.6" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] @@ -272,16 +265,16 @@ uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" version = "1.0.2" [[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +deps = ["MacroTools"] +git-tree-sha1 = "cda2cfaebb4be89c9084adaca7dd7333369715c5" uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" +version = "0.3.1" [[deps.Compat]] deps = ["TOML", "UUIDs"] -git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" +git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.15.0" +version = "4.16.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -313,17 +306,18 @@ uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" version = "2.4.2" [[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" +git-tree-sha1 = "76219f1ed5771adbb096743bff43fb5fdd4c1157" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.6" +version = "1.5.8" [deps.ConstructionBase.extensions] ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseLinearAlgebraExt = "LinearAlgebra" ConstructionBaseStaticArraysExt = "StaticArrays" [deps.ConstructionBase.weakdeps] IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [[deps.ContextVariablesX]] @@ -370,10 +364,10 @@ uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" version = "0.7.13" [[deps.DataFrames]] -deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "fb61b4812c49343d7ef0b533ba982c46021938a6" uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.6.1" +version = "1.7.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] @@ -390,6 +384,12 @@ version = "1.0.0" deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +[[deps.Dbus_jll]] +deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl"] +git-tree-sha1 = "fc173b380865f70627d7dd1190dc2fce6cc105af" +uuid = "ee1fde0b-3d02-5ea6-8484-8dfef6360eab" +version = "1.14.10+0" + [[deps.DecisionTree]] deps = ["AbstractTrees", "DelimitedFiles", "LinearAlgebra", "Random", "ScikitLearnBase", "Statistics"] git-tree-sha1 = "526ca14aaaf2d5a0e242f3a8a7966eb9065d7d78" @@ -436,9 +436,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] -git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e" +git-tree-sha1 = "d7477ecdafb813ddee2ae727afa94e9dcb5f3fb0" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.109" +version = "0.25.112" [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" @@ -461,12 +461,6 @@ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" -[[deps.DualNumbers]] -deps = ["Calculus", "NaNMath", "SpecialFunctions"] -git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" -uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" -version = "0.6.8" - [[deps.EpollShim_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "8e9441ee83492030ace98f9789a654a6d0b1f643" @@ -487,9 +481,9 @@ version = "2.6.2+0" [[deps.FFMPEG]] deps = ["FFMPEG_jll"] -git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8" +git-tree-sha1 = "53ebe7511fa11d33bec688a9178fac4e49eeee00" uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a" -version = "0.4.1" +version = "0.4.2" [[deps.FFMPEG_jll]] deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] @@ -511,24 +505,29 @@ version = "0.1.1" [[deps.FileIO]] deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" +git-tree-sha1 = "62ca0547a14c57e98154423419d8a342dca75ca9" uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.16.3" +version = "1.16.4" [[deps.FilePathsBase]] -deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" +deps = ["Compat", "Dates"] +git-tree-sha1 = "7878ff7172a8e6beedd1dea14bd27c3c6340d361" uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.21" +version = "0.9.22" +weakdeps = ["Mmap", "Test"] + + [deps.FilePathsBase.extensions] + FilePathsBaseMmapExt = "Mmap" + FilePathsBaseTestExt = "Test" [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] deps = ["LinearAlgebra"] -git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" +git-tree-sha1 = "6a70198746448456524cb442b8af316927ff3e1a" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.11.0" +version = "1.13.0" weakdeps = ["PDMats", "SparseArrays", "Statistics"] [deps.FillArrays.extensions] @@ -543,21 +542,27 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.8.5" [[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" +deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLDataDevices", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "Setfield", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "37fa32a50c69c10c6ea1465d3054d98c75bd7777" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.16" +version = "0.14.22" [deps.Flux.extensions] FluxAMDGPUExt = "AMDGPU" FluxCUDAExt = "CUDA" FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] + FluxEnzymeExt = "Enzyme" + FluxMPIExt = "MPI" + FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"] FluxMetalExt = "Metal" [deps.Flux.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.Fontconfig_jll]] @@ -595,25 +600,25 @@ version = "1.0.14+0" [[deps.Functors]] deps = ["LinearAlgebra"] -git-tree-sha1 = "8a66c07630d6428eaab3506a0eabfcf4a9edea05" +git-tree-sha1 = "64d8e93700c7a3f28f717d265382d52fac9fa1c1" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.11" +version = "0.4.12" [[deps.Future]] deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.GLFW_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "xkbcommon_jll"] -git-tree-sha1 = "3f74912a156096bd8fdbef211eff66ab446e7297" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "libdecor_jll", "xkbcommon_jll"] +git-tree-sha1 = "532f9126ad901533af1d4f5c198867227a7bb077" uuid = "0656b61e-2033-5cc2-a64a-77c0f6c09b89" -version = "3.4.0+0" +version = "3.4.0+1" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30" +git-tree-sha1 = "62ee71528cca49be797076a76bdc654a170a523e" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.3.0" +version = "10.3.1" [[deps.GPUArraysCore]] deps = ["Adapt"] @@ -623,15 +628,15 @@ version = "0.1.6" [[deps.GR]] deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Preferences", "Printf", "Qt6Wayland_jll", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "p7zip_jll"] -git-tree-sha1 = "629693584cef594c3f6f99e76e7a7ad17e60e8d5" +git-tree-sha1 = "ee28ddcd5517d54e417182fec3886e7412d3926f" uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" -version = "0.73.7" +version = "0.73.8" [[deps.GR_jll]] deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "FreeType2_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt6Base_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "a8863b69c2a0859f2c2c87ebdc4c6712e88bdf0d" +git-tree-sha1 = "f31929b9e67066bee48eec8b03c0df47d31a74b3" uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" -version = "0.73.7+0" +version = "0.73.8+0" [[deps.GZip]] deps = ["Libdl", "Zlib_jll"] @@ -647,9 +652,9 @@ version = "0.21.0+0" [[deps.Glib_jll]] deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "7c82e6a6cd34e9d935e9aa4051b66c6ff3af59ba" +git-tree-sha1 = "674ff0db93fffcd11a3573986e550d66cd4fd71f" uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" -version = "2.80.2+0" +version = "2.80.5+0" [[deps.Glob]] git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" @@ -680,10 +685,10 @@ version = "0.17.2" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" [[deps.HDF5_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] -git-tree-sha1 = "82a471768b513dc39e471540fdadc84ff80ff997" +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] +git-tree-sha1 = "38c8874692d48d5440d5752d6c74b0c6b0b60739" uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" -version = "1.14.3+3" +version = "1.14.2+1" [[deps.HTTP]] deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] @@ -692,22 +697,22 @@ uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" version = "1.10.8" [[deps.HarfBuzz_jll]] -deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] -git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" +deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll"] +git-tree-sha1 = "401e4f3f30f43af2c8478fc008da50096ea5240f" uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" -version = "2.8.1+1" +version = "8.3.1+0" [[deps.Hwloc_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "5e19e1e4fa3e71b774ce746274364aef0234634e" +git-tree-sha1 = "dd3b49277ec2bb2c6b94eb1604d4d0616016f7a6" uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" -version = "2.11.1+0" +version = "2.11.2+0" [[deps.HypergeometricFunctions]] -deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] -git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" +deps = ["LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] +git-tree-sha1 = "7c4195be1649ae622304031ed46a2f4df989f1eb" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.23" +version = "0.3.24" [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools"] @@ -762,14 +767,14 @@ uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" version = "0.7.0" [[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd" +git-tree-sha1 = "a779299d77cd080bf77b97535acecd73e1c5e5cb" uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.15" -weakdeps = ["Dates"] +version = "0.1.17" +weakdeps = ["Dates", "Test"] [deps.InverseFunctions.extensions] - DatesExt = "Dates" + InverseFunctionsDatesExt = "Dates" + InverseFunctionsTestExt = "Test" [[deps.InvertedIndices]] git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" @@ -787,22 +792,22 @@ uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" [[deps.JLD2]] -deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Reexport", "Requires", "TranscodingStreams", "UUIDs", "Unicode"] -git-tree-sha1 = "5fe858cb863e211c6dedc8cce2dc0752d4ab6e2b" +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"] +git-tree-sha1 = "b464b9b461ee989b435a689a4f7d870b68d467ed" uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.50" +version = "0.5.6" [[deps.JLFzf]] deps = ["Pipe", "REPL", "Random", "fzf_jll"] -git-tree-sha1 = "a53ebe394b71470c7f97c2e7e170d51df21b17af" +git-tree-sha1 = "39d64b09147620f5ffbf6b2d3255be3c901bec63" uuid = "1019f520-868f-41f5-a6de-eb00f4b6a39c" -version = "0.1.7" +version = "0.1.8" [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +git-tree-sha1 = "be3dc50a92e5a386872a493a10050136d4703f9b" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" +version = "1.6.1" [[deps.JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] @@ -812,9 +817,9 @@ version = "0.21.4" [[deps.JSON3]] deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] -git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" +git-tree-sha1 = "1d322381ef7b087548321d3f878cb4c9bd8f8f9b" uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -version = "1.14.0" +version = "1.14.1" [deps.JSON3.extensions] JSON3ArrowExt = ["ArrowTypes"] @@ -824,9 +829,9 @@ version = "1.14.0" [[deps.JpegTurbo_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "c84a835e1a09b289ffcd2271bf2a337bbdda6637" +git-tree-sha1 = "25ee0be4d43d0269027024d75a24c24d6c6e590c" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" -version = "3.0.3+0" +version = "3.0.4+0" [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] @@ -835,16 +840,20 @@ uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" version = "0.2.4" [[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" +deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "04e52f596d0871fa3890170fa79cb15e481e4cd8" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.22" +version = "0.9.28" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" + LinearAlgebraExt = "LinearAlgebra" + SparseArraysExt = "SparseArrays" [deps.KernelAbstractions.weakdeps] EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[deps.LAME_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -853,16 +862,16 @@ uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" version = "3.100.2+0" [[deps.LERC_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "bf36f528eec6634efc60d7ec062008f171071434" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "36bdbc52f13a7d1dcb0f3cd694e01677a515655b" uuid = "88015f11-f218-50d7-93a8-a6af411a945d" -version = "3.0.0+1" +version = "4.0.0+0" [[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Unicode"] +git-tree-sha1 = "d422dfd9707bec6617335dc2ea3c5172a87d5908" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "8.0.0" +version = "9.1.3" weakdeps = ["BFloat16s"] [deps.LLVM.extensions] @@ -870,39 +879,41 @@ weakdeps = ["BFloat16s"] [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" +git-tree-sha1 = "05a8bd5a42309a9ec82f700876903abce1017dd3" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.30+0" +version = "0.0.34+0" [[deps.LLVMOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "d986ce2d884d49126836ea94ed5bfb0f12679713" +git-tree-sha1 = "78211fb6cbc872f77cad3fc0b6cf647d923f4929" uuid = "1d63c593-3942-5779-bab2-d838dc0a180e" -version = "15.0.7+0" +version = "18.1.7+0" [[deps.LZO_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "70c5da094887fd2cae843b8db33920bac4b6f07d" +git-tree-sha1 = "854a9c268c43b77b0a27f22d7fab8d33cdb3a731" uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" -version = "2.10.2+0" +version = "2.10.2+1" [[deps.LaTeXStrings]] -git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" +git-tree-sha1 = "dda21b8cbd6a6c40d9d02a73230f9d70fed6918c" uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.1" +version = "1.4.0" [[deps.Latexify]] deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"] -git-tree-sha1 = "5b0d630f3020b82c0775a51d05895852f8506f50" +git-tree-sha1 = "ce5f5621cac23a86011836badfedf664a612cee4" uuid = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" -version = "0.16.4" +version = "0.16.5" [deps.Latexify.extensions] DataFramesExt = "DataFrames" + SparseArraysExt = "SparseArrays" SymEngineExt = "SymEngine" [deps.Latexify.weakdeps] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SymEngine = "123dc426-2d89-5057-bbad-38513e3affd8" [[deps.LazyArtifacts]] @@ -985,9 +996,9 @@ version = "2.40.1+0" [[deps.Libtiff_jll]] deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "LERC_jll", "Libdl", "XZ_jll", "Zlib_jll", "Zstd_jll"] -git-tree-sha1 = "2da088d113af58221c52828a80378e16be7d037a" +git-tree-sha1 = "b404131d06f7886402758c9ce2214b636eb4d54a" uuid = "89763e89-9b03-5906-acba-b20f662cd828" -version = "4.5.1+1" +version = "4.7.0+0" [[deps.Libuuid_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -1030,11 +1041,51 @@ git-tree-sha1 = "1d2dd9b186742b0f317f2530ddcbf00eebb18e96" uuid = "23992714-dd62-5051-b70f-ba57cb901cac" version = "0.10.7" +[[deps.MLDataDevices]] +deps = ["Adapt", "Compat", "Functors", "LinearAlgebra", "Preferences", "Random"] +git-tree-sha1 = "3207c2e66164e6366440ad3f0243a8d67abb4a47" +uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +version = "1.4.1" + + [deps.MLDataDevices.extensions] + MLDataDevicesAMDGPUExt = "AMDGPU" + MLDataDevicesCUDAExt = "CUDA" + MLDataDevicesChainRulesCoreExt = "ChainRulesCore" + MLDataDevicesFillArraysExt = "FillArrays" + MLDataDevicesGPUArraysExt = "GPUArrays" + MLDataDevicesMLUtilsExt = "MLUtils" + MLDataDevicesMetalExt = ["GPUArrays", "Metal"] + MLDataDevicesReactantExt = "Reactant" + MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools" + MLDataDevicesReverseDiffExt = "ReverseDiff" + MLDataDevicesSparseArraysExt = "SparseArrays" + MLDataDevicesTrackerExt = "Tracker" + MLDataDevicesZygoteExt = "Zygote" + MLDataDevicescuDNNExt = ["CUDA", "cuDNN"] + MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"] + + [deps.MLDataDevices.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" + GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" + MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" + RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" + [[deps.MLDatasets]] deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] -git-tree-sha1 = "55ed5f79697232389d894d05e93633a03e774818" +git-tree-sha1 = "361c2692ee730944764945859f1a6b31072e275d" uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" -version = "0.7.16" +version = "0.7.18" [[deps.MLJBase]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "RecipesBase", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] @@ -1054,12 +1105,6 @@ git-tree-sha1 = "90ef4d3b6cacec631c57cc034e1e61b4aa0ce511" uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661" version = "0.4.2" -[[deps.MLJFlux]] -deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables"] -git-tree-sha1 = "50c7f24b84005a2a80875c10d4f4059df17a0f68" -uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" -version = "0.5.1" - [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] git-tree-sha1 = "ceaff6618408d0e412619321ae43b33b40c1a733" @@ -1085,9 +1130,9 @@ version = "0.4.4" [[deps.MPICH_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "19d4bd098928a3263693991500d05d74dbdc2004" +git-tree-sha1 = "7715e65c47ba3941c502bffb7f266a41a7f54423" uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" -version = "4.2.2+0" +version = "4.2.3+0" [[deps.MPIPreferences]] deps = ["Libdl", "Preferences"] @@ -1097,9 +1142,9 @@ version = "0.1.11" [[deps.MPItrampoline_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "8c35d5420193841b2f367e658540e8d9e0601ed0" +git-tree-sha1 = "70e830dab5d0775183c99fc75e4c24c614ed7142" uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" -version = "5.4.0+0" +version = "5.5.1+0" [[deps.MacroTools]] deps = ["Markdown", "Random"] @@ -1132,18 +1177,6 @@ git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102" uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e" version = "0.3.2" -[[deps.Metalhead]] -deps = ["Artifacts", "BSON", "ChainRulesCore", "Flux", "Functors", "JLD2", "LazyArtifacts", "MLUtils", "NNlib", "PartialFunctions", "Random", "Statistics"] -git-tree-sha1 = "5aac9a2b511afda7bf89df5044a2e0b429f83152" -uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" -version = "0.9.3" - - [deps.Metalhead.extensions] - MetalheadCUDAExt = "CUDA" - - [deps.Metalhead.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - [[deps.MicroCollections]] deps = ["Accessors", "BangBang", "InitialValues"] git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" @@ -1182,10 +1215,10 @@ uuid = "6f286f6a-111f-5878-ab1e-185364afe411" version = "0.10.3" [[deps.NNlib]] -deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "da09a1e112fd75f9af2a5229323f01b56ec96a4c" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.21" +version = "0.9.24" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" @@ -1193,12 +1226,14 @@ version = "0.9.21" NNlibCUDAExt = "CUDA" NNlibEnzymeCoreExt = "EnzymeCore" NNlibFFTWExt = "FFTW" + NNlibForwardDiffExt = "ForwardDiff" [deps.NNlib.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.NPZ]] @@ -1255,10 +1290,10 @@ uuid = "05823500-19ac-5b8b-9628-191a04bc5112" version = "0.8.1+2" [[deps.OpenMPI_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML", "Zlib_jll"] +git-tree-sha1 = "bfce6d523861a6c562721b262c0d1aaeead2647f" uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" -version = "4.1.6+0" +version = "5.0.5+0" [[deps.OpenSSL]] deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] @@ -1268,9 +1303,9 @@ version = "1.4.3" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" +git-tree-sha1 = "7493f61f55a6cce7325f197443aa80d32554ba10" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.14+0" +version = "3.0.15+1" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -1285,10 +1320,10 @@ uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" version = "0.3.3" [[deps.Opus_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "51a08fb14ec28da2ec7a927c4337e4332c2a4720" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "6703a85cb3781bd5909d48730a67205f3f31a575" uuid = "91d4177d-7536-5919-b921-800302f37372" -version = "1.3.2+0" +version = "1.3.3+0" [[deps.OrderedCollections]] git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" @@ -1318,6 +1353,12 @@ git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" version = "0.5.12" +[[deps.Pango_jll]] +deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl"] +git-tree-sha1 = "e127b609fb9ecba6f201ba7ab753d5a605d53801" +uuid = "36c8627f-9965-5494-a995-c6b170f724f3" +version = "1.54.1+0" + [[deps.Parameters]] deps = ["OrderedCollections", "UnPack"] git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" @@ -1330,12 +1371,6 @@ git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" version = "2.8.1" -[[deps.PartialFunctions]] -deps = ["MacroTools"] -git-tree-sha1 = "47b49a4dbc23b76682205c646252c0f9e1eb75af" -uuid = "570af359-4316-4cb7-8c74-252c00c2016b" -version = "1.2.0" - [[deps.PeriodicTable]] deps = ["Base64", "Unitful"] git-tree-sha1 = "238aa6298007565529f911b734e18addd56985e1" @@ -1371,16 +1406,16 @@ uuid = "ccf2f8ad-2431-5c83-bf29-c5338b663b6a" version = "3.2.0" [[deps.PlotUtils]] -deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"] -git-tree-sha1 = "7b1a9df27f072ac4c9c7cbe5efb198489258d1f5" +deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "StableRNGs", "Statistics"] +git-tree-sha1 = "650a022b2ce86c7dcfbdecf00f78afeeb20e5655" uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" -version = "1.4.1" +version = "1.4.2" [[deps.Plots]] deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "PrecompileTools", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SparseArrays", "Statistics", "StatsBase", "TOML", "UUIDs", "UnicodeFun", "UnitfulLatexify", "Unzip"] -git-tree-sha1 = "082f0c4b70c202c37784ce4bfbc33c9f437685bf" +git-tree-sha1 = "45470145863035bb124ca51b320ed35d071cc6c2" uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -version = "1.40.5" +version = "1.40.8" [deps.Plots.extensions] FileIOExt = "FileIO" @@ -1426,9 +1461,9 @@ version = "0.4.2" [[deps.PrettyTables]] deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" +git-tree-sha1 = "1101cd475833706e4d0e7b122218257178f48f34" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.3.2" +version = "2.4.0" [[deps.Printf]] deps = ["Unicode"] @@ -1447,9 +1482,9 @@ uuid = "92933f4c-e287-5a05-a399-4b506db050ca" version = "1.10.2" [[deps.PtrArrays]] -git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759" +git-tree-sha1 = "77a42d78b6a92df47ab37e177b2deac405e1c88f" uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" -version = "1.2.0" +version = "1.2.1" [[deps.Qt6Base_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Vulkan_Loader_jll", "Xorg_libSM_jll", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_cursor_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "libinput_jll", "xkbcommon_jll"] @@ -1477,9 +1512,15 @@ version = "6.7.1+1" [[deps.QuadGK]] deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "9b23c31e76e333e6fb4c1595ae6afa74966a729e" +git-tree-sha1 = "cda3b045cf9ef07a08ad46731f5a3165e56cf3da" uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.9.4" +version = "2.11.1" + + [deps.QuadGK.extensions] + QuadGKEnzymeExt = "Enzyme" + + [deps.QuadGK.weakdeps] + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" [[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] @@ -1526,15 +1567,15 @@ version = "1.3.0" [[deps.Rmath]] deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" +git-tree-sha1 = "852bd0f55565a9e973fcfee83a84413270224dc4" uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.1" +version = "0.8.0" [[deps.Rmath_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21" +git-tree-sha1 = "58cdd8fb2201a6267e1db87ff148dd6c1dbd8ad8" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.2+0" +version = "0.5.1+0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -1590,9 +1631,9 @@ uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f" version = "1.0.3" [[deps.SimpleBufferStream]] -git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +git-tree-sha1 = "f305871d2f381d21527c770d4788c06c097c9bc1" uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" -version = "1.1.0" +version = "1.2.0" [[deps.SimpleTraits]] deps = ["InteractiveUtils", "MacroTools"] @@ -1650,9 +1691,9 @@ version = "0.1.1" [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" +git-tree-sha1 = "777657803913ffc7e8cc20f0fd04b634f871af8f" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.7" +version = "1.9.8" weakdeps = ["ChainRulesCore", "Statistics"] [deps.StaticArrays.extensions] @@ -1695,9 +1736,9 @@ version = "0.34.3" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a" +git-tree-sha1 = "b423576adc27097764a90e163157bcfc9acf0f46" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.3.1" +version = "1.3.2" weakdeps = ["ChainRulesCore", "InverseFunctions"] [deps.StatsFuns.extensions] @@ -1724,9 +1765,9 @@ version = "0.3.7" [[deps.StringManipulation]] deps = ["PrecompileTools"] -git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +git-tree-sha1 = "a6b1675a536c5ad1a60e5a5153e1fee12eb146e3" uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" -version = "0.3.4" +version = "0.4.0" [[deps.StructArrays]] deps = ["ConstructionBase", "DataAPI", "Tables"] @@ -1743,9 +1784,9 @@ weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] [[deps.StructTypes]] deps = ["Dates", "UUIDs"] -git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" +git-tree-sha1 = "159331b30e94d7b11379037feeb9b690950cace8" uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" -version = "1.10.0" +version = "1.11.0" [[deps.SuiteSparse]] deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] @@ -1774,10 +1815,9 @@ uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" version = "1.12.0" [[deps.TaijaBase]] -deps = ["CategoricalArrays", "Distributions", "Flux", "MLUtils", "Optimisers", "StatsBase", "Tables"] -git-tree-sha1 = "1c80c4472c6ab6e8c9fa544a22d907295b388dd0" +git-tree-sha1 = "4076f60078b12095ca71a2c26e2e4515e3a6a5e5" uuid = "10284c91-9f28-4c9a-abbf-ee43576dfff6" -version = "1.2.2" +version = "1.2.3" [[deps.TaijaData]] deps = ["CSV", "CounterfactualExplanations", "DataAPI", "DataFrames", "Flux", "LazyArtifacts", "MLDatasets", "MLJBase", "MLJModels", "Random", "StatsBase"] @@ -1801,21 +1841,18 @@ deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.TranscodingStreams]] -git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a" +git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.11.1" -weakdeps = ["Random", "Test"] - - [deps.TranscodingStreams.extensions] - TestExt = ["Test", "Random"] +version = "0.11.3" [[deps.Transducers]] -deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] -git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" +deps = ["Accessors", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] +git-tree-sha1 = "7deeab4ff96b85c5f72c824cae53a1398da3d1cb" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.82" +version = "0.4.84" [deps.Transducers.extensions] + TransducersAdaptExt = "Adapt" TransducersBlockArraysExt = "BlockArrays" TransducersDataFramesExt = "DataFrames" TransducersLazyArraysExt = "LazyArrays" @@ -1823,6 +1860,7 @@ version = "0.4.82" TransducersReferenceablesExt = "Referenceables" [deps.Transducers.weakdeps] + Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" @@ -1905,9 +1943,9 @@ version = "0.2.1" [[deps.UnsafeAtomicsLLVM]] deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5" +git-tree-sha1 = "2d17fabcd17e67d7625ce9c531fb9f40b7c42ce4" uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.5" +version = "0.2.1" [[deps.Unzip]] git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78" @@ -1945,9 +1983,9 @@ version = "1.6.1" [[deps.XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "d9717ce3518dc68a99e6b96300813760d887a01d" +git-tree-sha1 = "1165b0443d0eca63ac1e32b8c0eb69ed2f4f8127" uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.13.1+0" +version = "2.13.3+0" [[deps.XSLT_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "XML2_jll", "Zlib_jll"] @@ -2118,15 +2156,15 @@ version = "1.2.13+1" [[deps.Zstd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b" +git-tree-sha1 = "555d1076590a6cc2fdee2ef1469451f872d8b41b" uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" -version = "1.5.6+0" +version = "1.5.6+1" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" +git-tree-sha1 = "f816633be6dc5c0ed9ffedda157ecfda0b3b6a69" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.70" +version = "0.6.72" [deps.Zygote.extensions] ZygoteColorsExt = "Colors" @@ -2152,9 +2190,9 @@ version = "3.2.9+0" [[deps.fzf_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a68c9655fbe6dfcab3d972808f1aafec151ce3f8" +git-tree-sha1 = "936081b536ae4aa65415d869287d43ef3cb576b2" uuid = "214eeab7-80f7-51ab-84ad-2988db7cef09" -version = "0.43.0+0" +version = "0.53.0+0" [[deps.gperf_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -2175,15 +2213,21 @@ uuid = "a4ae2306-e953-59d6-aa16-d00cac43593b" version = "3.9.0+0" [[deps.libass_jll]] -deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] -git-tree-sha1 = "5982a94fcba20f02f42ace44b9894ee2b140fe47" +deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Zlib_jll"] +git-tree-sha1 = "e17c115d55c5fbb7e52ebedb427a0dca79d4484e" uuid = "0ac62f75-1d6f-5e53-bd7c-93b484bb37c0" -version = "0.15.1+0" +version = "0.15.2+0" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" +version = "5.11.0+0" + +[[deps.libdecor_jll]] +deps = ["Artifacts", "Dbus_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "Pango_jll", "Wayland_jll", "xkbcommon_jll"] +git-tree-sha1 = "9bf7903af251d2050b467f76bdbe57ce541f7f4f" +uuid = "1183f4f0-6f2a-5f1a-908b-139f9cdfea6f" +version = "0.2.2+0" [[deps.libevdev_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -2192,10 +2236,10 @@ uuid = "2db6ffa8-e38f-5e21-84af-90c45d0032cc" version = "1.11.0+0" [[deps.libfdk_aac_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "daacc84a041563f965be61859a36e17c4e4fcd55" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "8a22cf860a7d27e4f3498a0fe0811a7957badb38" uuid = "f638f0a6-7fb0-5443-88ba-1cc74229b280" -version = "2.0.2+0" +version = "2.0.3+0" [[deps.libinput_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "eudev_jll", "libevdev_jll", "mtdev_jll"] @@ -2205,15 +2249,15 @@ version = "1.18.0+0" [[deps.libpng_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "d7015d2e18a5fd9a4f47de711837e980519781a4" +git-tree-sha1 = "b70c870239dc3d7bc094eb2d6be9b73d27bef280" uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" -version = "1.6.43+1" +version = "1.6.44+0" [[deps.libvorbis_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"] -git-tree-sha1 = "b910cb81ef3fe6e78bf6acee440bda86fd6ae00c" +git-tree-sha1 = "490376214c4721cdaca654041f635213c6165cb3" uuid = "f27f6e37-5d2b-51aa-960f-b287f2bc3b7a" -version = "1.3.7+1" +version = "1.3.7+2" [[deps.mtdev_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] diff --git a/test/Project.toml b/test/Project.toml index 750ea47e..1e853790 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,7 +8,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" +MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/laplace.jl b/test/laplace.jl index 71089782..ea4cfbbe 100644 --- a/test/laplace.jl +++ b/test/laplace.jl @@ -259,7 +259,7 @@ end @test LaplaceRedux.has_softmax_or_sigmoid_final_layer(model) == false end -function train_nn(val::Dict; verbose=false) +function train_nn(val::Dict; verbosity=0) # Unpack: X = val[:X] Y = val[:Y] @@ -291,7 +291,7 @@ function train_nn(val::Dict; verbose=false) end update!(opt, Flux.params(nn), gs) end - if verbose && epoch % show_every == 0 + if verbosity>0 && epoch % show_every == 0 println("Epoch " * string(epoch)) @show avg_loss(data) end @@ -306,7 +306,7 @@ function run_workflow( backend::Symbol, subset_of_weights::Symbol; hessian_structure=:full, - verbose::Bool=false, + verbosity::Int=0, do_optimize_prior::Bool=true, do_predict::Bool=true, ) @@ -342,7 +342,7 @@ function run_workflow( ) fit!(la, data) if do_optimize_prior - optimize_prior!(la; verbose=verbose) + optimize_prior!(la; verbosity=verbosity) end if do_predict predict(la, X) diff --git a/test/mlj_flux_interfacing.jl b/test/mlj_flux_interfacing.jl deleted file mode 100644 index cb238daf..00000000 --- a/test/mlj_flux_interfacing.jl +++ /dev/null @@ -1,205 +0,0 @@ -using Random: Random -import Random.seed! -using MLJBase: MLJBase, categorical -using MLJFlux -using Flux -using StableRNGs - -@testset "Regression" begin - function basictest_regression(X, y, builder, optimiser, threshold) - optimiser = deepcopy(optimiser) - - stable_rng = StableRNGs.StableRNG(123) - - model = LaplaceRegression(; - builder=builder, - optimiser=optimiser, - acceleration=MLJBase.CPUThreads(), - loss=Flux.Losses.mse, - rng=stable_rng, - lambda=-1.0, - alpha=-1.0, - epochs=-1, - batch_size=-1, - subset_of_weights=:incorrect, - hessian_structure=:incorrect, - backend=:incorrect, - ret_distr=true, - ) - - fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - - history = _report.training_losses - @test length(history) == model.epochs + 1 - - # increase iterations and check update is incremental: - model.epochs = model.epochs + 3 - - fitresult, cache, _report = @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - @test :chain in keys(MLJBase.fitted_params(model, fitresult)) - - history = _report.training_losses - @test length(history) == model.epochs + 1 - - yhat = MLJBase.predict(model, fitresult, X) - - # start fresh with small epochs: - model = LaplaceRegression(; - builder=builder, - optimiser=optimiser, - epochs=2, - acceleration=CPU1(), - rng=stable_rng, - ) - - fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - - # change batch_size and check it performs cold restart: - model.batch_size = 2 - fitresult, cache, _report = @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - # change learning rate and check it does *not* restart: - model.optimiser.eta /= 2 - fitresult, cache, _report = @test_logs( - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - # set `optimiser_changes_trigger_retraining = true` and change - # learning rate and check it does restart: - model.optimiser_changes_trigger_retraining = true - model.optimiser.eta /= 2 - @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - return true - end - - seed!(1234) - N = 300 - X = MLJBase.table(rand(Float32, N, 4)) - ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N) - builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) - optimiser = Flux.Optimise.Adam(0.03) - - @test basictest_regression(X, ycont, builder, optimiser, 0.9) -end - -@testset "Classification" begin - function basictest_classification(X, y, builder, optimiser, threshold) - optimiser = deepcopy(optimiser) - - stable_rng = StableRNGs.StableRNG(123) - - model = LaplaceClassification(; - builder=builder, - optimiser=optimiser, - loss=Flux.crossentropy, - epochs=-1, - batch_size=-1, - lambda=-1.0, - alpha=-1.0, - rng=stable_rng, - acceleration=MLJBase.CPUThreads(), - subset_of_weights=:incorrect, - hessian_structure=:incorrect, - backend=:incorrect, - link_approx=:incorrect, - ) - - # Test that shape is correct: - @test MLJFlux.shape(model, X, y)[2] == length(unique(y)) - fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - - history = _report.training_losses - @test length(history) == model.epochs + 1 - - # increase iterations and check update is incremental: - model.epochs = model.epochs + 3 - - fitresult, cache, _report = @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - @test :chain in keys(MLJBase.fitted_params(model, fitresult)) - - yhat = MLJBase.predict(model, fitresult, X) - - history = _report.training_losses - @test length(history) == model.epochs + 1 - - # start fresh with small epochs: - model = LaplaceClassification(; - builder=builder, - optimiser=optimiser, - epochs=2, - acceleration=CPU1(), - rng=stable_rng, - ) - - fitresult, cache, _report = MLJBase.fit(model, 0, X, y) - - # change batch_size and check it performs cold restart: - model.batch_size = 2 - fitresult, cache, _report = @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - # change learning rate and check it does *not* restart: - model.optimiser.eta /= 2 - fitresult, cache, _report = @test_logs( - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - # set `optimiser_changes_trigger_retraining = true` and change - # learning rate and check it does restart: - model.optimiser_changes_trigger_retraining = true - model.optimiser.eta /= 2 - @test_logs( - (:info, r""), # one line of :info per extra epoch - (:info, r""), - MLJBase.update(model, 2, fitresult, cache, X, y) - ) - - return true - end - - seed!(1234) - N = 300 - X = MLJBase.table(rand(Float32, N, 4)) - ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N) - m, M = minimum(ycont), maximum(ycont) - _, a, b, _ = collect(range(m; stop=M, length=4)) - y = categorical( - map(ycont) do η - if η < 0.9 * a - 'a' - elseif η < 1.1 * b - 'b' - else - 'c' - end - end, - ) - - builder = MLJFlux.MLP(; hidden=(16, 8), σ=Flux.relu) - optimizer = Flux.Optimise.Adam(0.03) - @test basictest_classification(X, y, builder, optimizer, 0.9) -end diff --git a/test/runtests.jl b/test/runtests.jl index 0459cf5a..92d97033 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,7 +35,4 @@ using Test include("krondecomposed.jl") end - @testset "MLJFlux" begin - include("mlj_flux_interfacing.jl") - end end