From ed306354e0163a330510fc218dfac40a5f0ee4b1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 24 May 2022 00:28:20 -0400 Subject: [PATCH 1/6] enforce SciMLStyle --- .JuliaFormatter.toml | 2 + .github/workflows/FormatCheck.yml | 43 ++++++ .github/workflows/FormatPR.yml | 29 ++++ docs/make.jl | 19 ++- docs/src/design/core.md | 19 +++ examples/Basics/main.jl | 17 ++- examples/BayesianNN/main.jl | 21 ++- examples/ImageNet/main.jl | 214 ++++++++++++++------------- examples/NeuralODE/main.jl | 75 +++++----- examples/SimpleRNN/main.jl | 35 +++-- src/Lux.jl | 5 +- src/adapt.jl | 24 +-- src/autodiff.jl | 41 ++--- src/core.jl | 49 +++--- src/layers/basic.jl | 186 +++++++++++++++-------- src/layers/conv.jl | 144 ++++++++++-------- src/layers/display.jl | 34 +++-- src/layers/dropout.jl | 20 +-- src/layers/normalize.jl | 238 +++++++++++++++--------------- src/layers/recurrent.jl | 137 +++++++++-------- src/nnlib.jl | 183 ++++++++++++----------- src/transform.jl | 27 ++-- src/utils.jl | 63 +++++--- test/layers/basic.jl | 90 ++++++----- test/layers/conv.jl | 137 +++++++++-------- test/layers/dropout.jl | 22 +-- test/layers/normalize.jl | 62 ++++---- test/layers/recurrent.jl | 62 ++++---- test/models/convnets.jl | 108 ++++++-------- test/runtests.jl | 4 +- test/utils.jl | 13 +- 31 files changed, 1198 insertions(+), 925 deletions(-) create mode 100644 .JuliaFormatter.toml create mode 100644 .github/workflows/FormatCheck.yml create mode 100644 .github/workflows/FormatPR.yml create mode 100644 docs/src/design/core.md diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 000000000..93a9e7665 --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1,2 @@ +style = "sciml" +whitespace_in_kwargs = true diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml new file mode 100644 index 000000000..c5aa42e41 --- /dev/null +++ b/.github/workflows/FormatCheck.yml @@ -0,0 +1,43 @@ +name: format-check + +on: + push: + branches: + - 'main' + - 'release-' + tags: '*' + pull_request: + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + julia-version: [1.3.0] + julia-arch: [x86] + os: [ubuntu-latest] + steps: + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.julia-version }} + + - uses: actions/checkout@v1 + - name: Install JuliaFormatter and format + # This will use the latest version by default but you can set the version like so: + # + # julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))' + run: | + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using JuliaFormatter; format(".", verbose=true)' + - name: Format check + run: | + julia -e ' + out = Cmd(`git diff --name-only`) |> read |> String + if out == "" + exit(0) + else + @error "Some files have not been formatted !!!" 
+ write(stdout, out) + exit(1) + end' + \ No newline at end of file diff --git a/.github/workflows/FormatPR.yml b/.github/workflows/FormatPR.yml new file mode 100644 index 000000000..3a4c959aa --- /dev/null +++ b/.github/workflows/FormatPR.yml @@ -0,0 +1,29 @@ +name: format-pr +on: + schedule: + - cron: '0 0 * * *' +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Install JuliaFormatter and format + run: | + julia -e 'import Pkg; Pkg.add("JuliaFormatter")' + julia -e 'using JuliaFormatter; format(".")' + # https://github.com/marketplace/actions/create-pull-request + # https://github.com/peter-evans/create-pull-request#reference-example + - name: Create Pull Request + id: cpr + uses: peter-evans/create-pull-request@v3 + with: + token: ${{ secrets.GITHUB_TOKEN }} + commit-message: Format .jl files + title: 'Automatic JuliaFormatter.jl run' + branch: auto-juliaformatter-pr + delete-branch: true + labels: formatting, automated pr, no changelog + - name: Check outputs + run: | + echo "Pull Request Number - ${{ steps.cpr.outputs.pull-request-number }}" + echo "Pull Request URL - ${{ steps.cpr.outputs.pull-request-url }}" \ No newline at end of file diff --git a/docs/make.jl b/docs/make.jl index 2dc70db62..4fd8b12bd 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,7 +2,7 @@ using Documenter, Lux, Literate, Pkg # Precompile example dependencies Pkg.activate(joinpath(@__DIR__, "..", "examples")) -Pkg.develop(PackageSpec(; path=joinpath(@__DIR__, ".."))) +Pkg.develop(PackageSpec(; path = joinpath(@__DIR__, ".."))) Pkg.instantiate() Pkg.precompile() Pkg.activate(@__DIR__) @@ -13,7 +13,8 @@ if haskey(ENV, "GITHUB_ACTIONS") end deployconfig = Documenter.auto_detect_deploy_system() -Documenter.post_status(deployconfig; type="pending", repo="github.com/avik-pal/Lux.jl.git") +Documenter.post_status(deployconfig; type = "pending", + repo = "github.com/avik-pal/Lux.jl.git") # Tutorials get_example_path(p) = joinpath(@__DIR__, "..", "examples", p) @@ -28,10 +29,15 @@ ADVANCED_TUTORIAL_NAMES = [] MAPPING = Dict("beginner" => [], "intermediate" => [], "advanced" => []) for (d, names, paths) in - (("beginner", BEGINNER_TUTORIAL_NAMES, BEGINNER_TUTORIALS), ("intermediate", INTERMEDIATE_TUTORIAL_NAMES, INTERMEDIATE_TUTORIALS), ("advanced", ADVANCED_TUTORIAL_NAMES, ADVANCED_TUTORIALS)) + (("beginner", BEGINNER_TUTORIAL_NAMES, BEGINNER_TUTORIALS), + ("intermediate", INTERMEDIATE_TUTORIAL_NAMES, INTERMEDIATE_TUTORIALS), + ("advanced", ADVANCED_TUTORIAL_NAMES, ADVANCED_TUTORIALS)) for (n, p) in zip(names, paths) - Literate.markdown(get_example_path(p), joinpath(OUTPUT, d, dirname(p)); documenter=true) - push!(MAPPING[d], n => joinpath("examples/generated", d, dirname(p), splitext(basename(p))[1] * ".md")) + Literate.markdown(get_example_path(p), joinpath(OUTPUT, d, dirname(p)); + documenter = true) + push!(MAPPING[d], + n => joinpath("examples/generated", d, dirname(p), + splitext(basename(p))[1] * ".md")) end end @@ -70,6 +76,7 @@ makedocs(; ], ) -deploydocs(; repo="github.com/avik-pal/Lux.jl.git", push_preview=true, devbranch="main") +deploydocs(; repo = "github.com/avik-pal/Lux.jl.git", push_preview = true, + devbranch = "main") Pkg.activate(@__DIR__) diff --git a/docs/src/design/core.md b/docs/src/design/core.md new file mode 100644 index 000000000..3e1ffa4fd --- /dev/null +++ b/docs/src/design/core.md @@ -0,0 +1,19 @@ +# Adding New Functionality/Layers + +For Style we try to follow 
[SciMLStyle](https://github.com/SciML/SciMLStyle). The only reason we don't have a badge yet is that we haven't yet updated the package to follow all the guidelines. Here we document some additional guidelines that we enforce:
+
+## Mutability
+
+See https://github.com/SciML/SciMLStyle#out-of-place-and-immutability-is-preferred-when-sufficient-performant for reference. This is strictly enforced, i.e. all layers/functions provided as part of the external API must be pure functions, even if they come with a performance penalty.
+
+## Branching -- Generated Functions
+
+Zygote doesn't like branches in code. Like it or not, we are stuck with it for the near future. Even if Julia is able to optimize branches away, Zygote will most certainly throw away those optimizations (this can be checked via `Zygote.@code_ir`).
+
+### Writing efficient non-branching code to make Zygote happy
+
+* Rely on `@generated` functions to remove **most** runtime branching. Some examples:
+  * Layers behaving differently during training and inference -- we know at compile time whether a layer is being run in training/inference mode via `istraining(st)`.
+  * Composite layers relying on a variable number of internal layers -- again, we know the number of internal layers at compile time, so we can manually unroll the loops. See [`Parallel`](@ref), [`Chain`](@ref), etc.
+* Pass around `Val` in the state. `Flux.jl` sets `training` to one of `(:auto, true, false)`, so which branch will be taken has to be determined at runtime (*bad*). If we instead pass `Val(true)`, we can specialize functions directly on `true`, `false`, etc., ensuring there is no runtime cost for these operations. See [`BatchNorm`](@ref), [`Dropout`](@ref), etc.
+
diff --git a/examples/Basics/main.jl b/examples/Basics/main.jl
index d21da7587..927f352ff 100644
--- a/examples/Basics/main.jl
+++ b/examples/Basics/main.jl
@@ -68,11 +68,11 @@ x .+ 1
 
 # We can see Julia tile the column vector `1:5` across all rows of the larger array.
 
-zeros(5,5) .+ (1:5)
+zeros(5, 5) .+ (1:5)
 
 # The x' syntax is used to transpose a column `1:5` into an equivalent row, and Julia will tile that across columns.
 
-zeros(5,5) .+ (1:5)'
+zeros(5, 5) .+ (1:5)'
 
 # We can use this to make a times table.
 
@@ -121,13 +121,13 @@ Random.seed!(rng, 0)
 
 # First, let us run a random number generator 3 times with the `replicate`d rng
 
-for i = 1:3
+for i in 1:3
     println("Iteration $i ", rand(Lux.replicate(rng), 10))
 end
 
 # As expected we get the same output. 
We can remove the `replicate` call and we will get different outputs -for i = 1:3 +for i in 1:3 println("Iteration $i ", rand(rng, 10)) end @@ -155,8 +155,10 @@ v = randn(rng, Float32, 4) # Let's use AbstractDifferentiation and Zygote to compute the gradients println("Actual Gradient: ", ∇f(v)) -println("Computed Gradient via Reverse Mode AD (Zygote): ", AD.gradient(AD.ZygoteBackend(), f, v)[1]) -println("Computed Gradient via Forward Mode AD (ForwardDiff): ", AD.gradient(AD.ForwardDiffBackend(), f, v)[1]) +println("Computed Gradient via Reverse Mode AD (Zygote): ", + AD.gradient(AD.ZygoteBackend(), f, v)[1]) +println("Computed Gradient via Forward Mode AD (ForwardDiff): ", + AD.gradient(AD.ForwardDiffBackend(), f, v)[1]) # Note that `AD.gradient` will only work for scalar valued outputs @@ -248,6 +250,7 @@ for i in 1:100 ## Perform parameter update opt_state, ps = Optimisers.update(opt_state, ps, gs) if i % 10 == 1 || i == 100 - println("Loss Value after $i iterations: ", mse(model, ps, st, x_samples, y_samples)) + println("Loss Value after $i iterations: ", + mse(model, ps, st, x_samples, y_samples)) end end diff --git a/examples/BayesianNN/main.jl b/examples/BayesianNN/main.jl index 175d6581b..8f7d098aa 100644 --- a/examples/BayesianNN/main.jl +++ b/examples/BayesianNN/main.jl @@ -32,14 +32,14 @@ x2s = rand(rng, Float32, M) * 4.5f0; xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M]) x1s = rand(rng, Float32, M) * 4.5f0; x2s = rand(rng, Float32, M) * 4.5f0; -append!(xt1s, Array([[x1s[i] - 5f0; x2s[i] - 5f0] for i in 1:M])) +append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M])) x1s = rand(rng, Float32, M) * 4.5f0; x2s = rand(rng, Float32, M) * 4.5f0; -xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5f0] for i in 1:M]) +xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M]) x1s = rand(rng, Float32, M) * 4.5f0; x2s = rand(rng, Float32, M) * 4.5f0; -append!(xt0s, Array([[x1s[i] - 5f0; x2s[i] + 0.5f0] for i in 1:M])) +append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M])) ## Store all the data for later. xs = [xt1s; xt0s] @@ -52,8 +52,8 @@ function plot_data() x2 = first.(xt0s) y2 = last.(xt0s) - plt = Plots.scatter(x1, y1; color="red", clim=(0, 1)) - Plots.scatter!(plt, x2, y2; color="blue", clim=(0, 1)) + plt = Plots.scatter(x1, y1; color = "red", clim = (0, 1)) + Plots.scatter!(plt, x2, y2; color = "blue", clim = (0, 1)) return plt end @@ -135,12 +135,11 @@ _, i = findmax(ch[:lp]) i = i.I[1] ## Plot the posterior distribution with a contour plot -x1_range = collect(range(-6; stop=6, length=25)) -x2_range = collect(range(-6; stop=6, length=25)) +x1_range = collect(range(-6; stop = 6, length = 25)) +x2_range = collect(range(-6; stop = 6, length = 25)) Z = [nn_forward([x1, x2], theta[i, :])[1] for x1 in x1_range, x2 in x2_range] contour!(x1_range, x2_range, Z) - # The contour plot above shows that the MAP method is not too bad at classifying our data. Now we can visualize our predictions. 
# $p(\tilde{x} | X, \alpha) = \int_{\theta} p(\tilde{x} | \theta) p(\theta | X, \alpha) \approx \sum_{\theta \sim p(\theta | X, \alpha)}f_{\theta}(\tilde{x})$ @@ -158,8 +157,8 @@ end plot_data() n_end = 1500 -x1_range = collect(range(-6; stop=6, length=25)) -x2_range = collect(range(-6; stop=6, length=25)) +x1_range = collect(range(-6; stop = 6, length = 25)) +x2_range = collect(range(-6; stop = 6, length = 25)) Z = [nn_predict([x1, x2], theta, n_end)[1] for x1 in x1_range, x2 in x2_range] contour!(x1_range, x2_range, Z) @@ -171,5 +170,5 @@ n_end = 1000 anim = @gif for i in 1:n_end plot_data() Z = [nn_forward([x1, x2], theta[i, :])[1] for x1 in x1_range, x2 in x2_range] - contour!(x1_range, x2_range, Z; title="Iteration $i", clim=(0, 1)) + contour!(x1_range, x2_range, Z; title = "Iteration $i", clim = (0, 1)) end every 5 diff --git a/examples/ImageNet/main.jl b/examples/ImageNet/main.jl index 5d70ac050..e11fdc50d 100644 --- a/examples/ImageNet/main.jl +++ b/examples/ImageNet/main.jl @@ -26,17 +26,17 @@ import DataLoaders: LearnBase # Extending Datasets import MLUtils # Distributed Training -FluxMPI.Init(;verbose=true) +FluxMPI.Init(; verbose = true) CUDA.allowscalar(false) # unsafe_free OneHotArrays CUDA.unsafe_free!(x::OneHotArray) = CUDA.unsafe_free!(x.indices) # Image Classification Models -VGG11_BN(args...; kwargs...) = VGG11(args...; batchnorm=true, kwargs...) -VGG13_BN(args...; kwargs...) = VGG13(args...; batchnorm=true, kwargs...) -VGG16_BN(args...; kwargs...) = VGG16(args...; batchnorm=true, kwargs...) -VGG19_BN(args...; kwargs...) = VGG19(args...; batchnorm=true, kwargs...) +VGG11_BN(args...; kwargs...) = VGG11(args...; batchnorm = true, kwargs...) +VGG13_BN(args...; kwargs...) = VGG13(args...; batchnorm = true, kwargs...) +VGG16_BN(args...; kwargs...) = VGG16(args...; batchnorm = true, kwargs...) +VGG19_BN(args...; kwargs...) = VGG19(args...; batchnorm = true, kwargs...) MobileNetv3_small(args...; kwargs...) = MobileNetv3(:small, args...; kwargs...) MobileNetv3_large(args...; kwargs...) = MobileNetv3(:large, args...; kwargs...) ResNeXt50(args...; kwargs...) = ResNeXt(50, args...; kwargs...) @@ -75,7 +75,8 @@ AVAILABLE_IMAGENET_MODELS = [ IMAGENET_MODELS_DICT = Dict(string(model) => model for model in AVAILABLE_IMAGENET_MODELS) -function get_model(model_name::String, models_dict::Dict, rng, args...; warmup=true, kwargs...) +function get_model(model_name::String, models_dict::Dict, rng, args...; warmup = true, + kwargs...) 
model = Lux.transform(models_dict[model_name](args...; kwargs...).layers) ps, st = Lux.setup(rng, model) .|> gpu if warmup @@ -85,14 +86,15 @@ function get_model(model_name::String, models_dict::Dict, rng, args...; warmup=t should_log() && println("$(now()) ==> staring `$model_name` warmup...") model(x__, ps, st) should_log() && println("$(now()) ==> forward pass warmup completed") - (l, _, _), back = Zygote.pullback(p -> logitcrossentropyloss(x__, y__, model, p, st), ps) + (l, _, _), back = Zygote.pullback(p -> logitcrossentropyloss(x__, y__, model, p, st), + ps) back((one(l), nothing, nothing)) should_log() && println("$(now()) ==> backward pass warmup completed") end if is_distributed() - ps = FluxMPI.synchronize!(ps; root_rank=0) - st = FluxMPI.synchronize!(st; root_rank=0) + ps = FluxMPI.synchronize!(ps; root_rank = 0) + st = FluxMPI.synchronize!(st; root_rank = 0) should_log() && println("$(now()) ==> models synced across all ranks") end @@ -104,61 +106,61 @@ function parse_commandline_arguments() parse_settings = ArgParseSettings("Lux ImageNet Training") @add_arg_table! parse_settings begin "--arch" - default = "ResNet18" - range_tester = x -> x ∈ keys(IMAGENET_MODELS_DICT) - help = "model architectures: " * join(keys(IMAGENET_MODELS_DICT), ", ", " or ") + default = "ResNet18" + range_tester = x -> x ∈ keys(IMAGENET_MODELS_DICT) + help = "model architectures: " * join(keys(IMAGENET_MODELS_DICT), ", ", " or ") "--epochs" - help = "number of total epochs to run" - arg_type = Int - default = 90 + help = "number of total epochs to run" + arg_type = Int + default = 90 "--start-epoch" - help = "manual epoch number (useful on restarts)" - arg_type = Int - default = 0 + help = "manual epoch number (useful on restarts)" + arg_type = Int + default = 0 "--batch-size" - help = "mini-batch size, this is the total batch size across all GPUs" - arg_type = Int - default = 256 + help = "mini-batch size, this is the total batch size across all GPUs" + arg_type = Int + default = 256 "--learning-rate" - help = "initial learning rate" - arg_type = Float32 - default = 0.1f0 + help = "initial learning rate" + arg_type = Float32 + default = 0.1f0 "--momentum" - help = "momentum" - arg_type = Float32 - default = 0.9f0 + help = "momentum" + arg_type = Float32 + default = 0.9f0 "--weight-decay" - help = "weight decay" - arg_type = Float32 - default = 1.0f-4 + help = "weight decay" + arg_type = Float32 + default = 1.0f-4 "--print-freq" - help = "print frequency" - arg_type = Int - default = 10 + help = "print frequency" + arg_type = Int + default = 10 "--resume" - help = "resume from checkpoint" - arg_type = String - default = "" + help = "resume from checkpoint" + arg_type = String + default = "" "--evaluate" - help = "evaluate model on validation set" - action = :store_true + help = "evaluate model on validation set" + action = :store_true "--pretrained" - help = "use pre-trained model" - action = :store_true + help = "use pre-trained model" + action = :store_true "--seed" - help = "seed for initializing training. " - arg_type = Int - default = 0 + help = "seed for initializing training. 
" + arg_type = Int + default = 0 "data" - help = "path to dataset" - required = true + help = "path to dataset" + required = true end return parse_args(parse_settings) end # Loss Function -logitcrossentropyloss(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1)) +logitcrossentropyloss(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims = 1); dims = 1)) function logitcrossentropyloss(x, y, model, ps, st) ŷ, st_ = model(x, ps, st) @@ -179,16 +181,17 @@ end update_lr(st_opt::NamedTuple, eta) = fmap(l -> update_lr(l, eta), st_opt) # Accuracy -function accuracy(ŷ, y, topk=(1,)) +function accuracy(ŷ, y, topk = (1,)) maxk = maximum(topk) - pred_labels = partialsortperm.(eachcol(ŷ), (1:maxk,), rev=true) + pred_labels = partialsortperm.(eachcol(ŷ), (1:maxk,), rev = true) true_labels = onecold(y) accuracies = Vector{Float32}(undef, length(topk)) for (i, k) in enumerate(topk) - accuracies[i] = sum(map((a, b) -> sum(view(a, 1:k) .== b), pred_labels, true_labels)) + accuracies[i] = sum(map((a, b) -> sum(view(a, 1:k) .== b), pred_labels, + true_labels)) end return accuracies .* 100 ./ size(y, ndims(y)) @@ -199,28 +202,28 @@ is_distributed() = FluxMPI.Initialized() && total_workers() > 1 should_log() = !FluxMPI.Initialized() || local_rank() == 0 # Checkpointing -function save_checkpoint(state, is_best, filename="checkpoint.pth.tar") +function save_checkpoint(state, is_best, filename = "checkpoint.pth.tar") if should_log() serialize(filename, state) if is_best - cp(filename, "model_best.pth.tar"; force=true) + cp(filename, "model_best.pth.tar"; force = true) end end end # DataLoading struct ImageDataset - image_files - labels - mapping - augmentation_pipeline - normalization_parameters + image_files::Any + labels::Any + mapping::Any + augmentation_pipeline::Any + normalization_parameters::Any end function ImageDataset(folder::String, augmentation_pipeline, normalization_parameters) ulabels = readdir(folder) label_dirs = joinpath.((folder,), ulabels) - @assert length(label_dirs) == 1000 "There should be 1000 subdirectories in $folder" + @assert length(label_dirs)==1000 "There should be 1000 subdirectories in $folder" classes = readlines(joinpath(@__DIR__, "synsets.txt")) mapping = Dict(z => i for (i, z) in enumerate(ulabels)) @@ -228,7 +231,8 @@ function ImageDataset(folder::String, augmentation_pipeline, normalization_param istrain = endswith(folder, r"train|train/") if istrain - image_files = vcat(map((x, y) -> joinpath.((x,), y), label_dirs, readdir.(label_dirs))...) + image_files = vcat(map((x, y) -> joinpath.((x,), y), label_dirs, + readdir.(label_dirs))...) remove_files = [ "n01739381_1309.JPEG", @@ -253,15 +257,15 @@ function ImageDataset(folder::String, augmentation_pipeline, normalization_param "n04596742_4225.JPEG", "n07583066_647.JPEG", "n13037406_4650.JPEG", - "n02105855_2933.JPEG" + "n02105855_2933.JPEG", ] - remove_files = joinpath.( - (folder,), joinpath.(first.(rsplit.(remove_files, "_", limit=2)), remove_files) - ) - + remove_files = joinpath.((folder,), + joinpath.(first.(rsplit.(remove_files, "_", limit = 2)), + remove_files)) + image_files = [setdiff(Set(image_files), Set(remove_files))...] - labels = [mapping[x] for x in map(x -> x[2], rsplit.(image_files, "/", limit=3))] + labels = [mapping[x] for x in map(x -> x[2], rsplit.(image_files, "/", limit = 3))] else vallist = hcat(split.(readlines(joinpath(@__DIR__, "val_list.txt")))...) 
labels = parse.(Int, vallist[2, :]) .+ 1 @@ -272,7 +276,8 @@ function ImageDataset(folder::String, augmentation_pipeline, normalization_param labels = labels[idxs] end - return ImageDataset(image_files, labels, mapping, augmentation_pipeline, normalization_parameters) + return ImageDataset(image_files, labels, mapping, augmentation_pipeline, + normalization_parameters) end LearnBase.nobs(data::ImageDataset) = length(data.image_files) @@ -301,7 +306,7 @@ LearnBase.getobs(data::DistributedDataContainer, i::Int) = MLUtils.getobs(data, # Tracking Base.@kwdef mutable struct AverageMeter - fmtstr + fmtstr::Any val::Float64 = 0.0 sum::Float64 = 0.0 count::Int = 0 @@ -310,7 +315,7 @@ end function AverageMeter(name::String, fmt::String) fmtstr = FormatExpr("$name {1:$fmt} ({2:$fmt})") - return AverageMeter(; fmtstr=fmtstr) + return AverageMeter(; fmtstr = fmtstr) end function update!(meter::AverageMeter, val, n::Int) @@ -324,11 +329,11 @@ end print_meter(meter::AverageMeter) = printfmt(meter.fmtstr, meter.val, meter.average) struct ProgressMeter{N} - batch_fmtstr - meters::NTuple{N,AverageMeter} + batch_fmtstr::Any + meters::NTuple{N, AverageMeter} end -function ProgressMeter(num_batches::Int, meters::NTuple{N}, prefix::String="") where {N} +function ProgressMeter(num_batches::Int, meters::NTuple{N}, prefix::String = "") where {N} fmt = "%" * string(length(string(num_batches))) * "d" prefix = prefix != "" ? endswith(prefix, " ") ? prefix : prefix * " " : "" batch_fmtstr = generate_formatter("$prefix[$fmt/" * sprintf1(fmt, num_batches) * "]") @@ -351,7 +356,9 @@ function validate(val_loader, model, ps, st, args) top1 = AverageMeter("Acc@1", "6.2f") top5 = AverageMeter("Acc@5", "6.2f") - progress = ProgressMeter(length(val_loader), (batch_time, data_time, forward_time, losses, top1, top5), "Val:") + progress = ProgressMeter(length(val_loader), + (batch_time, data_time, forward_time, losses, top1, top5), + "Val:") st_ = Lux.testmode(st) t = time() @@ -395,7 +402,9 @@ function train(train_loader, model, ps, st, optimiser_state, epoch, args) losses = AverageMeter("Loss", ".4e") top1 = AverageMeter("Acc@1", "6.2f") top5 = AverageMeter("Acc@5", "6.2f") - progress = ProgressMeter(length(train_loader), (batch_time, data_time, forward_time, backward_time, optimize_time, losses, top1, top5), "Epoch: [$epoch]") + progress = ProgressMeter(length(train_loader), + (batch_time, data_time, forward_time, backward_time, + optimize_time, losses, top1, top5), "Epoch: [$epoch]") st = Lux.trainmode(st) @@ -404,7 +413,8 @@ function train(train_loader, model, ps, st, optimiser_state, epoch, args) t_data, t = time() - t, time() # Gradients and Update - (loss, ŷ, st), back = Zygote.pullback(p -> logitcrossentropyloss(x, y, model, p, st), ps) + (loss, ŷ, st), back = Zygote.pullback(p -> logitcrossentropyloss(x, y, model, p, + st), ps) t_forward, t = time() - t, time() gs = back((one(loss) / total_workers(), nothing, nothing))[1] t_backward, t = time() - t, time() @@ -454,44 +464,39 @@ function main(args) println("$(now()) => creating model `$(args["arch"])`") end end - model, ps, st = get_model(args["arch"], IMAGENET_MODELS_DICT, rng; warmup=true, pretrain=args["pretrained"]) + model, ps, st = get_model(args["arch"], IMAGENET_MODELS_DICT, rng; warmup = true, + pretrain = args["pretrained"]) - normalization_parameters = ( - mean=reshape([0.485f0, 0.456f0, 0.406f0], 1, 1, 3), - std=reshape([0.229f0, 0.224f0, 0.225f0], 1, 1, 3) - ) + normalization_parameters = (mean = reshape([0.485f0, 0.456f0, 0.406f0], 1, 1, 3), + std = 
reshape([0.229f0, 0.224f0, 0.225f0], 1, 1, 3)) train_data_augmentation = Resize(256, 256) |> FlipX(0.5) |> RCropSize(224, 224) val_data_augmentation = Resize(256, 256) |> CropSize(224, 224) - train_dataset = ImageDataset( - joinpath(args["data"], "train"), - train_data_augmentation, - normalization_parameters - ) - val_dataset = ImageDataset( - joinpath(args["data"], "val"), - val_data_augmentation, - normalization_parameters - ) + train_dataset = ImageDataset(joinpath(args["data"], "train"), + train_data_augmentation, + normalization_parameters) + val_dataset = ImageDataset(joinpath(args["data"], "val"), + val_data_augmentation, + normalization_parameters) if is_distributed() train_dataset = DistributedDataContainer(train_dataset) val_dataset = DistributedDataContainer(val_dataset) end - train_loader = DataLoader(shuffleobs(train_dataset), args["batch-size"] ÷ total_workers()) + train_loader = DataLoader(shuffleobs(train_dataset), + args["batch-size"] ÷ total_workers()) val_loader = DataLoader(val_dataset, args["batch-size"] ÷ total_workers()) # Optimizer and Scheduler should_log() && println("$(now()) => creating optimiser") - optimiser = Optimisers.OptimiserChain( - Optimisers.Momentum(args["learning-rate"], args["momentum"]), - Optimisers.WeightDecay(args["weight-decay"]) - ) + optimiser = Optimisers.OptimiserChain(Optimisers.Momentum(args["learning-rate"], + args["momentum"]), + Optimisers.WeightDecay(args["weight-decay"])) optimiser_state = Optimisers.setup(optimiser, ps) if is_distributed() optimiser_state = FluxMPI.synchronize!(optimiser_state) should_log() && println("$(now()) ==> synced optimiser state across all ranks") end - scheduler = Step(λ=args["learning-rate"], γ=0.1f0, step_sizes=30) + scheduler = Step(λ = args["learning-rate"], γ = 0.1f0, step_sizes = 30) if args["resume"] != "" if isfile(args["resume"]) @@ -500,9 +505,11 @@ function main(args) optimiser_state = checkpoint["optimiser_state"] |> gpu ps = checkpoint["model_parameters"] |> gpu st = checkpoint["model_states"] |> gpu - should_log() && println("$(now()) => loaded checkpoint `$(args["resume"])` (epoch $(args["start-epoch"]))") + should_log() && + println("$(now()) => loaded checkpoint `$(args["resume"])` (epoch $(args["start-epoch"]))") else - should_log() && println("$(now()) => no checkpoint found at `$(args["resume"])`") + should_log() && + println("$(now()) => no checkpoint found at `$(args["resume"])`") end end @@ -517,7 +524,8 @@ function main(args) for epoch in args["start-epoch"]:args["epochs"] # Train for 1 epoch - ps, st, optimiser_state, _ = train(train_loader, model, ps, st, optimiser_state, epoch, args) + ps, st, optimiser_state, _ = train(train_loader, model, ps, st, optimiser_state, + epoch, args) # Some Housekeeping GC.gc(true) @@ -538,13 +546,11 @@ function main(args) is_best = acc1 > best_acc1 best_acc1 = max(acc1, best_acc1) - save_state = Dict( - "epoch" => epoch, - "arch" => args["arch"], - "model_states" => st |> cpu, - "model_parameters" => ps |> cpu, - "optimiser_state" => optimiser_state |> cpu, - ) + save_state = Dict("epoch" => epoch, + "arch" => args["arch"], + "model_states" => st |> cpu, + "model_parameters" => ps |> cpu, + "optimiser_state" => optimiser_state |> cpu) save_checkpoint(save_state, is_best) end end diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index 5b74888ed..bb87db840 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -4,7 +4,8 @@ using Lux using Pkg #hide Pkg.activate(joinpath(dirname(pathof(Lux)), "..", "examples")) 
#hide -using ComponentArrays, CUDA, DiffEqSensitivity, NNlib, Optimisers, OrdinaryDiffEq, Random, Statistics, Zygote +using ComponentArrays, CUDA, DiffEqSensitivity, NNlib, Optimisers, OrdinaryDiffEq, Random, + Statistics, Zygote import MLDatasets: MNIST import MLDataUtils: convertlabel, LabelEnc import MLUtils: DataLoader, splitobs @@ -12,7 +13,9 @@ CUDA.allowscalar(false) # ## Loading MNIST ## Use MLDataUtils LabelEnc for natural onehot conversion -onehot(labels_raw) = convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9))) +function onehot(labels_raw) + convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9))) +end function loadmnist(batchsize, train_split) ## Load MNIST: Only 1500 for demonstration purposes @@ -23,21 +26,21 @@ function loadmnist(batchsize, train_split) ## Process images into (H,W,C,BS) batches x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) y_data = onehot(labels_raw) - (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split) + (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at = train_split) return ( - ## Use DataLoader to automatically minibatch and shuffle the data - DataLoader(collect.((x_train, y_train)); batchsize=batchsize, shuffle=true), - ## Don't shuffle the test data - DataLoader(collect.((x_test, y_test)); batchsize=batchsize, shuffle=false), - ) + ## Use DataLoader to automatically minibatch and shuffle the data + DataLoader(collect.((x_train, y_train)); batchsize = batchsize, shuffle = true), + ## Don't shuffle the test data + DataLoader(collect.((x_test, y_test)); batchsize = batchsize, shuffle = false)) end # ## Define the Neural ODE Layer # # The NeuralODE is a ContainerLayer. It stores a `model` and the parameters and states of the NeuralODE is # same as that of the underlying model. -struct NeuralODE{M<:Lux.AbstractExplicitLayer,So,Se,T,K} <: Lux.AbstractExplicitContainerLayer{(:model,)} +struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, Se, T, K} <: + Lux.AbstractExplicitContainerLayer{(:model,)} model::M solver::So sensealg::Se @@ -45,13 +48,11 @@ struct NeuralODE{M<:Lux.AbstractExplicitLayer,So,Se,T,K} <: Lux.AbstractExplicit kwargs::K end -function NeuralODE( - model::Lux.AbstractExplicitLayer; - solver=Tsit5(), - sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()), - tspan=(0.0f0, 1.0f0), - kwargs..., -) +function NeuralODE(model::Lux.AbstractExplicitLayer; + solver = Tsit5(), + sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()), + tspan = (0.0f0, 1.0f0), + kwargs...) 
return NeuralODE(model, solver, sensealg, tspan, kwargs) end @@ -61,28 +62,27 @@ function (n::NeuralODE)(x, ps, st) return u_ end prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps) - return solve(prob, n.solver; sensealg=n.sensealg, n.kwargs...), st + return solve(prob, n.solver; sensealg = n.sensealg, n.kwargs...), st end -diffeqsol_to_array(x::ODESolution{T,N,<:AbstractVector{<:CuArray}}) where {T,N} = dropdims(gpu(x); dims=3) -diffeqsol_to_array(x::ODESolution) = dropdims(Array(x); dims=3) +function diffeqsol_to_array(x::ODESolution{T, N, <:AbstractVector{<:CuArray}}) where {T, N} + dropdims(gpu(x); dims = 3) +end +diffeqsol_to_array(x::ODESolution) = dropdims(Array(x); dims = 3) # ## Create and Initialize the Neural ODE Layer function create_model() ## Construct the Neural ODE Model - model = Chain( - FlattenLayer(), - Dense(784, 20, tanh), - NeuralODE( - Chain(Dense(20, 10, tanh), Dense(10, 10, tanh), Dense(10, 20, tanh)); - save_everystep=false, - reltol=1.0f-3, - abstol=1.0f-3, - save_start=false, - ), - diffeqsol_to_array, - Dense(20, 10), - ) + model = Chain(FlattenLayer(), + Dense(784, 20, tanh), + NeuralODE(Chain(Dense(20, 10, tanh), Dense(10, 10, tanh), + Dense(10, 20, tanh)); + save_everystep = false, + reltol = 1.0f-3, + abstol = 1.0f-3, + save_start = false), + diffeqsol_to_array, + Dense(20, 10)) rng = Random.default_rng() Random.seed!(rng, 0) @@ -128,7 +128,8 @@ function train() st_opt = Optimisers.setup(opt, ps) ### Warmup the Model - img, lab = gpu(train_dataloader.data[1][:, :, :, 1:1]), gpu(train_dataloader.data[2][:, 1:1]) + img, lab = gpu(train_dataloader.data[1][:, :, :, 1:1]), + gpu(train_dataloader.data[2][:, 1:1]) loss(img, lab, model, ps, st) (l, _), back = pullback(p -> loss(img, lab, model, p, st), ps) back((one(l), nothing)) @@ -145,11 +146,9 @@ function train() end ttime = time() - stime - println( - "[$epoch/$nepochs] \t Time $(round(ttime; digits=2))s \t Training Accuracy: " * - "$(round(accuracy(model, ps, st, train_dataloader) * 100; digits=2))% \t " * - "Test Accuracy: $(round(accuracy(model, ps, st, test_dataloader) * 100; digits=2))%" - ) + println("[$epoch/$nepochs] \t Time $(round(ttime; digits=2))s \t Training Accuracy: " * + "$(round(accuracy(model, ps, st, train_dataloader) * 100; digits=2))% \t " * + "Test Accuracy: $(round(accuracy(model, ps, st, test_dataloader) * 100; digits=2))%") end end diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index 651408a50..d10a211c9 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -16,25 +16,25 @@ using MLUtils, Optimisers, Zygote, NNlib, Random, Statistics # We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create a `MLUtils.DataLoader`. 
Our dataloader will give us sequences of size 2 × seq_len × batch_size and we need to predict a binary value whether the sequence is clockwise or anticlockwise -function get_dataloaders(; dataset_size=1000, sequence_length=50) +function get_dataloaders(; dataset_size = 1000, sequence_length = 50) ## Create the spirals data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size] ## Get the labels labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2)) - clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for d in data[1:(dataset_size ÷ 2)]] - anticlockwise_spirals = [ - reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for d in data[((dataset_size ÷ 2) + 1):end] - ] - x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3)) + clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) + for d in data[1:(dataset_size ÷ 2)]] + anticlockwise_spirals = [reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, + 1) for d in data[((dataset_size ÷ 2) + 1):end]] + x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims = 3)) ## Split the dataset - (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true) + (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at = 0.8, + shuffle = true) ## Create DataLoaders return ( - ## Use DataLoader to automatically minibatch and shuffle the data - DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true), - ## Don't shuffle the validation data - DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false), - ) + ## Use DataLoader to automatically minibatch and shuffle the data + DataLoader(collect.((x_train, y_train)); batchsize = 128, shuffle = true), + ## Don't shuffle the validation data + DataLoader(collect.((x_val, y_val)); batchsize = 128, shuffle = false)) end # ## Creating a Classifier @@ -43,7 +43,8 @@ end # We pass the fieldnames `lstm_cell` and `classifier` to the type to ensure that the parameters and states are automatically populated and we don't have to define [`Lux.initialparameters`](@ref) and [`Lux.initialstates`](@ref). 
-struct SpiralClassifier{L,C} <: Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)} +struct SpiralClassifier{L, C} <: + Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)} lstm_cell::L classifier::C end @@ -51,12 +52,14 @@ end # We won't define the model from scratch but rather use the [`Lux.LSTMCell`](@ref) and [`Lux.Dense`](@ref) function SpiralClassifier(in_dims, hidden_dims, out_dims) - return SpiralClassifier(LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)) + return SpiralClassifier(LSTMCell(in_dims => hidden_dims), + Dense(hidden_dims => out_dims, sigmoid)) end # Now we need to define the behavior of the Classifier when it is invoked -function (s::SpiralClassifier)(x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple) where {T} +function (s::SpiralClassifier)(x::AbstractArray{T, 3}, ps::NamedTuple, + st::NamedTuple) where {T} ## First we will have to run the sequence through the LSTM Cell ## The first call to LSTM Cell will create the initial hidden state ## See that the parameters and states are automatically populated into a field called `lstm_cell` @@ -69,7 +72,7 @@ function (s::SpiralClassifier)(x::AbstractArray{T,3}, ps::NamedTuple, st::NamedT ## After running through the sequence we will pass the output through the classifier y, st_classifier = s.classifier(h, ps.classifier, st.classifier) ## Finally remember to create the updated state - st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm)) + st = merge(st, (classifier = st_classifier, lstm_cell = st_lstm)) return vec(y), st end diff --git a/src/Lux.jl b/src/Lux.jl index 3b8395fde..abc02b761 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -20,7 +20,7 @@ using Optimisers # Optional Dependency using Requires -const use_cuda = Ref{Union{Nothing,Bool}}(nothing) +const use_cuda = Ref{Union{Nothing, Bool}}(nothing) # Data Transfer Utilities include("adapt.jl") @@ -50,7 +50,8 @@ export cpu, gpu # Layers export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer export Dense, Scale -export Conv, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool, AdaptiveMaxPool, AdaptiveMeanPool, Upsample +export Conv, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool, AdaptiveMaxPool, + AdaptiveMeanPool, Upsample export Dropout, VariationalHiddenDropout export BatchNorm, GroupNorm export WeightNorm diff --git a/src/adapt.jl b/src/adapt.jl index 9a98356bb..45b7d3312 100644 --- a/src/adapt.jl +++ b/src/adapt.jl @@ -6,20 +6,26 @@ struct LuxCUDAAdaptor <: LuxDeviceAdaptor end adapt_storage(::LuxCUDAAdaptor, x) = CUDA.cu(x) adapt_storage(::LuxCUDAAdaptor, x::FillArrays.AbstractFill) = CUDA.cu(colelct(x)) adapt_storage(::LuxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x)) -adapt_storage(to::LuxCUDAAdaptor, x::ComponentArray) = ComponentArray(adapt_storage(to, getdata(x)), getaxes(x)) +function adapt_storage(to::LuxCUDAAdaptor, x::ComponentArray) + ComponentArray(adapt_storage(to, getdata(x)), getaxes(x)) +end adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng -function adapt_storage( - ::LuxCPUAdaptor, - x::Union{AbstractRange,FillArrays.AbstractFill,Zygote.OneElement,SparseArrays.AbstractSparseArray}, -) +function adapt_storage(::LuxCPUAdaptor, + x::Union{AbstractRange, FillArrays.AbstractFill, Zygote.OneElement, + SparseArrays.AbstractSparseArray}) return x end adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x) -adapt_storage(to::LuxCPUAdaptor, x::ComponentArray) = ComponentArray(adapt_storage(to, getdata(x)), getaxes(x)) +function adapt_storage(to::LuxCPUAdaptor, 
x::ComponentArray) + ComponentArray(adapt_storage(to, getdata(x)), getaxes(x)) +end adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng # TODO: SparseArrays -adapt_storage(::LuxCPUAdaptor, x::CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix) = adapt(Array, x) +function adapt_storage(::LuxCPUAdaptor, + x::CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix) + adapt(Array, x) +end _isbitsarray(::AbstractArray{<:Number}) = true _isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T) @@ -42,7 +48,7 @@ Transfer `x` to GPU """ function gpu(x) check_use_cuda() - return use_cuda[] ? fmap(x -> adapt(LuxCUDAAdaptor(), x), x; exclude=_isleaf) : x + return use_cuda[] ? fmap(x -> adapt(LuxCUDAAdaptor(), x), x; exclude = _isleaf) : x end function check_use_cuda() @@ -53,7 +59,7 @@ function check_use_cuda() end if !(use_cuda[]) @info """The GPU function is being called but the GPU is not accessible. - Defaulting back to the CPU. (No action is required if you want to run on the CPU).""" maxlog = 1 + Defaulting back to the CPU. (No action is required if you want to run on the CPU).""" maxlog=1 end end end diff --git a/src/autodiff.jl b/src/autodiff.jl index 6e9a3da70..0773df613 100644 --- a/src/autodiff.jl +++ b/src/autodiff.jl @@ -1,6 +1,7 @@ # Non Differentiable Functions ChainRulesCore.@non_differentiable replicate(::Any) -ChainRulesCore.@non_differentiable update_statistics(::Any, ::Any, ::Any, ::Any, ::Any, ::Any, ::Any) +ChainRulesCore.@non_differentiable update_statistics(::Any, ::Any, ::Any, ::Any, ::Any, + ::Any, ::Any) ChainRulesCore.@non_differentiable generate_dropout_mask(::Any, ::Any, ::Any) ChainRulesCore.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any) ChainRulesCore.@non_differentiable glorot_normal(::Any...) @@ -8,23 +9,23 @@ ChainRulesCore.@non_differentiable glorot_uniform(::Any...) ChainRulesCore.@non_differentiable check_use_cuda() ChainRulesCore.@non_differentiable istraining(::Any) -ChainRulesCore.Tangent{P}(; kwargs...) where {P<:AbstractExplicitLayer} = NoTangent() +ChainRulesCore.Tangent{P}(; kwargs...) where {P <: AbstractExplicitLayer} = NoTangent() ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),) -ChainRulesCore.rrule(::typeof(Base.broadcasted), ::typeof(identity), x) = x, Δ -> (NoTangent(), NoTangent(), Δ) +function ChainRulesCore.rrule(::typeof(Base.broadcasted), ::typeof(identity), x) + x, Δ -> (NoTangent(), NoTangent(), Δ) +end # NNlib Functions -function ChainRulesCore.rrule( - ::typeof(batchnorm), - g::CuArray{T}, - b::CuArray{T}, - x::Union{CuArray{T,4},CuArray{T,5}}, - running_mean, - running_var, - momentum; - kwargs..., -) where {T<:CUDNNFloat} +function ChainRulesCore.rrule(::typeof(batchnorm), + g::CuArray{T}, + b::CuArray{T}, + x::Union{CuArray{T, 4}, CuArray{T, 5}}, + running_mean, + running_var, + momentum; + kwargs...) where {T <: CUDNNFloat} y = batchnorm(g, b, x, running_mean, running_var, momentum; kwargs...) function batchnorm_pullback(dy) dg, db, dx = ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kwargs...) 
@@ -34,9 +35,8 @@ function ChainRulesCore.rrule( end # Activation Rrules -function ChainRulesCore.rrule( - ::typeof(applyactivation), f::cudnnValidActivationTypes, x::CuArray{T} -) where {T<:CUDNNFloat} +function ChainRulesCore.rrule(::typeof(applyactivation), f::cudnnValidActivationTypes, + x::CuArray{T}) where {T <: CUDNNFloat} mode = getCUDNNActivationMode(f) y = CUDNN.cudnnActivationForward(x; mode) function applyactivation_pullback(Δ) @@ -68,12 +68,15 @@ function ChainRulesCore.rrule(::typeof(Array), x::CUDA.CuArray) return Array(x), d -> (NoTangent(), CUDA.cu(d)) end -function ChainRulesCore.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::CUDA.AbstractGPUArray) - return adapt_storage(to, x), d -> (NoTangent(), NoTangent(), adapt_storage(LuxCUDAAdaptor(), d)) +function ChainRulesCore.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, + x::CUDA.AbstractGPUArray) + return adapt_storage(to, x), + d -> (NoTangent(), NoTangent(), adapt_storage(LuxCUDAAdaptor(), d)) end function ChainRulesCore.rrule(::typeof(adapt_storage), to::LuxCUDAAdaptor, x::Array) - return adapt_storage(to, x), d -> (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), d)) + return adapt_storage(to, x), + d -> (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), d)) end # RNN Helpers diff --git a/src/core.jl b/src/core.jl index ef5bc7ddb..e3d6f3678 100644 --- a/src/core.jl +++ b/src/core.jl @@ -14,7 +14,9 @@ abstract type AbstractExplicitLayer end Generate the initial parameters of the layer `l`. """ initialparameters(::AbstractRNG, ::Any) = NamedTuple() -initialparameters(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1(initialparameters, rng), l) +function initialparameters(rng::AbstractRNG, l::NamedTuple) + map(Base.Fix1(initialparameters, rng), l) +end """ initialstates(rng::AbstractRNG, l) @@ -29,21 +31,24 @@ initialstates(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1(initialstates, rn Return the total number of parameters of the layer `l`. """ -parameterlength(l::AbstractExplicitLayer) = parameterlength(initialparameters(Random.default_rng(), l)) -parameterlength(nt::Union{NamedTuple,Tuple}) = length(nt) == 0 ? 0 : sum(parameterlength, nt) +function parameterlength(l::AbstractExplicitLayer) + parameterlength(initialparameters(Random.default_rng(), l)) +end +function parameterlength(nt::Union{NamedTuple, Tuple}) + length(nt) == 0 ? 0 : sum(parameterlength, nt) +end parameterlength(a::AbstractArray) = length(a) parameterlength(x) = 0 - """ statelength(l) Return the total number of states of the layer `l`. """ statelength(l::AbstractExplicitLayer) = statelength(initialstates(Random.default_rng(), l)) -statelength(nt::Union{NamedTuple,Tuple}) = length(nt) == 0 ? 0 : sum(statelength, nt) +statelength(nt::Union{NamedTuple, Tuple}) = length(nt) == 0 ? 0 : sum(statelength, nt) statelength(a::AbstractArray) = length(a) -statelength(x::Union{Number,Symbol}) = 1 +statelength(x::Union{Number, Symbol}) = 1 statelength(x) = 0 """ @@ -51,17 +56,22 @@ statelength(x) = 0 Shorthand for getting the parameters and states of the layer `l`. Is equivalent to `(initialparameters(rng, l), initialstates(rng, l))`. 
""" -setup(rng::AbstractRNG, l::AbstractExplicitLayer) = (initialparameters(rng, l), initialstates(rng, l)) +function setup(rng::AbstractRNG, l::AbstractExplicitLayer) + (initialparameters(rng, l), initialstates(rng, l)) +end """ apply(model::AbstractExplicitLayer, x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) Simply calls `model(x, ps, st)` """ -apply(model::AbstractExplicitLayer, x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) = model(x, ps, st) +function apply(model::AbstractExplicitLayer, x, ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) + model(x, ps, st) +end function Base.show(io::IO, x::AbstractExplicitLayer) - __t = rsplit(string(get_typename(x)), "."; limit=2) + __t = rsplit(string(get_typename(x)), "."; limit = 2) T = length(__t) == 2 ? __t[2] : __t[1] print(io, "$T()") end @@ -69,21 +79,25 @@ end # Abstract Container Layers abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer end -function initialparameters(rng::AbstractRNG, l::AbstractExplicitContainerLayer{layers}) where {layers} +function initialparameters(rng::AbstractRNG, + l::AbstractExplicitContainerLayer{layers}) where {layers} length(layers) == 1 && return initialparameters(rng, getfield(l, layers[1])) return NamedTuple{layers}(initialparameters.(rng, getfield.((l,), layers))) end -function initialstates(rng::AbstractRNG, l::AbstractExplicitContainerLayer{layers}) where {layers} +function initialstates(rng::AbstractRNG, + l::AbstractExplicitContainerLayer{layers}) where {layers} length(layers) == 1 && return initialstates(rng, getfield(l, layers[1])) return NamedTuple{layers}(initialstates.(rng, getfield.((l,), layers))) end -parameterlength(l::AbstractExplicitContainerLayer{layers}) where {layers} = +function parameterlength(l::AbstractExplicitContainerLayer{layers}) where {layers} sum(parameterlength, getfield.((l,), layers)) +end -statelength(l::AbstractExplicitContainerLayer{layers}) where {layers} = +function statelength(l::AbstractExplicitContainerLayer{layers}) where {layers} sum(statelength, getfield.((l,), layers)) +end # Test Mode """ @@ -91,25 +105,26 @@ statelength(l::AbstractExplicitContainerLayer{layers}) where {layers} = Make all occurances of `training` in state `st` `!mode` """ -testmode(st::NamedTuple, mode::Bool=true) = update_state(st, :training, Val(!mode)) +testmode(st::NamedTuple, mode::Bool = true) = update_state(st, :training, Val(!mode)) """ trainmode(x::Any, mode::Bool=true) Make all occurances of `training` in state `st` `mode` """ -trainmode(x::Any, mode::Bool=true) = testmode(x, !mode) +trainmode(x::Any, mode::Bool = true) = testmode(x, !mode) """ update_state(st::NamedTuple, key::Symbol, value; layer_check=_default_layer_check(key)) Recursively update all occurances of the `key` in the state `st` with the `value`. 
""" -function update_state(st::NamedTuple, key::Symbol, value; layer_check=_default_layer_check(key)) +function update_state(st::NamedTuple, key::Symbol, value; + layer_check = _default_layer_check(key)) function _update_state(st, key::Symbol, value) return Setfield.set(st, Setfield.PropertyLens{key}(), value) end - return fmap(_st -> _update_state(_st, key, value), st; exclude=layer_check) + return fmap(_st -> _update_state(_st, key, value), st; exclude = layer_check) end function _default_layer_check(key) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index f279ae407..13ab4423a 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -17,14 +17,16 @@ Reshapes the passed array to have a size of `(dims..., :)` * Empty `NamedTuple()` """ struct ReshapeLayer{N} <: AbstractExplicitLayer - dims::NTuple{N,Int} + dims::NTuple{N, Int} end @inline function (r::ReshapeLayer)(x::AbstractArray, ps, st::NamedTuple) return reshape(x, r.dims..., size(x, ndims(x))), st end -Base.show(io::IO, r::ReshapeLayer) = print(io, "ReshapeLayer(output_dims = (", join(r.dims, ", "), ", :))") +function Base.show(io::IO, r::ReshapeLayer) + print(io, "ReshapeLayer(output_dims = (", join(r.dims, ", "), ", :))") +end """ FlattenLayer() @@ -42,7 +44,7 @@ Flattens the passed array into a matrix. """ struct FlattenLayer <: AbstractExplicitLayer end -@inline function (f::FlattenLayer)(x::AbstractArray{T,N}, ps, st::NamedTuple) where {T,N} +@inline function (f::FlattenLayer)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N} return reshape(x, :, size(x, N)), st end @@ -72,7 +74,9 @@ end @inline (s::SelectDim)(x, ps, st::NamedTuple) = selectdim(x, s.dim, s.i), st -Base.show(io::IO, s::SelectDim) = print(io, "SelectDim(dim = ", s.dim, ", index = ", s.i, ")") +function Base.show(io::IO, s::SelectDim) + print(io, "SelectDim(dim = ", s.dim, ", index = ", s.i, ")") +end """ NoOpLayer() @@ -170,12 +174,14 @@ The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`. See [`Parallel`](@ref) for a more general implementation. """ -struct SkipConnection{T<:AbstractExplicitLayer,F} <: AbstractExplicitContainerLayer{(:layers,)} +struct SkipConnection{T <: AbstractExplicitLayer, F} <: + AbstractExplicitContainerLayer{(:layers,)} layers::T connection::F end -@inline function (skip::SkipConnection)(x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) +@inline function (skip::SkipConnection)(x, ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) mx, st = skip.layers(x, ps, st) return skip.connection(mx, x), st end @@ -209,7 +215,7 @@ Create a layer which passes an input to each path in `layers`, before reducing t See also [`SkipConnection`](@ref) which is `Parallel` with one identity. """ -struct Parallel{F,T<:NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)} +struct Parallel{F, T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)} connection::F layers::T end @@ -219,13 +225,13 @@ function Parallel(connection, layers...) 
return Parallel(connection, NamedTuple{names}(layers)) end -function (m::Parallel)(x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) +function (m::Parallel)(x, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) return applyparallel(m.layers, m.connection, x, ps, st) end -@generated function applyparallel( - layers::NamedTuple{names}, connection::C, x::T, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple -) where {names,C,T} +@generated function applyparallel(layers::NamedTuple{names}, connection::C, x::T, + ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) where {names, C, T} N = length(names) y_symbols = [gensym() for _ in 1:(N + 1)] st_symbols = [gensym() for _ in 1:N] @@ -296,7 +302,7 @@ l = BranchLayer( ) ``` """ -struct BranchLayer{T<:NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)} +struct BranchLayer{T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)} layers::T end @@ -305,20 +311,31 @@ function BranchLayer(layers...) return BranchLayer(NamedTuple{names}(layers)) end -(m::BranchLayer)(x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) = applybranching(m.layers, x, ps, st) +function (m::BranchLayer)(x, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) + applybranching(m.layers, x, ps, st) +end -@generated function applybranching( - layers::NamedTuple{names}, x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple -) where {names} +@generated function applybranching(layers::NamedTuple{names}, x, + ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) where {names} N = length(names) y_symbols = [gensym() for _ in 1:N] st_symbols = [gensym() for _ in 1:N] calls = [] +<<<<<<< HEAD append!( calls, [:(($(y_symbols[i]), $(st_symbols[i])) = layers[$i](x, ps.$(names[i]), st.$(names[i]))) for i in 1:N] ) push!(calls, :(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))) push!(calls, :(return tuple($(Tuple(y_symbols)...)), st)) +======= + append!(calls, + [:(($(y_symbols[i]), $(st_symbols[i])) = layers[$i](x, ps.$(names[i]), + st.$(names[i]))) + for i in 1:N]) + append!(calls, [:(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))]) + append!(calls, [:(return tuple($(Tuple(y_symbols)...)), st)]) +>>>>>>> 862526f (enforce SciMLStyle) return Expr(:block, calls...) end @@ -378,7 +395,7 @@ end * States of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` """ -struct PairwiseFusion{F,T<:NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)} +struct PairwiseFusion{F, T <: NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)} connection::F layers::T end @@ -388,18 +405,19 @@ function PairwiseFusion(connection, layers...) return PairwiseFusion(connection, NamedTuple{names}(layers)) end -function (m::PairwiseFusion)(x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) +function (m::PairwiseFusion)(x, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) return applypairwisefusion(m.layers, m.connection, x, ps, st) end -@generated function applypairwisefusion( - layers::NamedTuple{names}, connection::C, x::T, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple -) where {names,C,T} +@generated function applypairwisefusion(layers::NamedTuple{names}, connection::C, x::T, + ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) where {names, C, T} N = length(names) y_symbols = [gensym() for _ in 1:(N + 1)] st_symbols = [gensym() for _ in 1:N] getinput(i) = T <: Tuple ? 
:(x[$i]) : :x calls = [:($(y_symbols[N + 1]) = $(getinput(1)))] +<<<<<<< HEAD append!( calls, [ @@ -412,6 +430,18 @@ end ) push!(calls, :(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))) push!(calls, :(return $(y_symbols[N + 1]), st)) +======= + for i in 1:N + push!(calls, + :(($(y_symbols[i]), $(st_symbols[i])) = layers[$i]($(y_symbols[N + 1]), + ps.$(names[i]), + st.$(names[i])))) + push!(calls, + :($(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1))))) + end + append!(calls, [:(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))]) + append!(calls, [:(return $(y_symbols[N + 1]), st)]) +>>>>>>> 862526f (enforce SciMLStyle) return Expr(:block, calls...) end @@ -469,7 +499,7 @@ c = Chain( """ struct Chain{T} <: AbstractExplicitContainerLayer{(:layers,)} layers::T - function Chain(xs...; disable_optimizations::Bool=false) + function Chain(xs...; disable_optimizations::Bool = false) xs = disable_optimizations ? xs : flatten_model(xs) length(xs) == 0 && return NoOpLayer() length(xs) == 1 && return first(xs) @@ -477,17 +507,19 @@ struct Chain{T} <: AbstractExplicitContainerLayer{(:layers,)} layers = NamedTuple{names}(xs) return new{typeof(layers)}(layers) end - Chain(xs::AbstractVector; disable_optimizations::Bool=false) = Chain(xs...; disable_optimizations) + function Chain(xs::AbstractVector; disable_optimizations::Bool = false) + Chain(xs...; disable_optimizations) + end end -function flatten_model(layers::Union{AbstractVector,Tuple}) +function flatten_model(layers::Union{AbstractVector, Tuple}) new_layers = [] for l in layers f = flatten_model(l) if f isa Tuple || f isa AbstractVector append!(new_layers, f) elseif f isa Function - if !hasmethod(f, (Any, Union{ComponentArray,NamedTuple}, NamedTuple)) + if !hasmethod(f, (Any, Union{ComponentArray, NamedTuple}, NamedTuple)) push!(new_layers, WrappedFunction(f)) else push!(new_layers, f) @@ -505,15 +537,18 @@ end flatten_model(x) = x -(c::Chain)(x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) = applychain(c.layers, x, ps, st) +function (c::Chain)(x, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) + applychain(c.layers, x, ps, st) +end -@generated function applychain( - layers::NamedTuple{fields}, x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple{fields} -) where {fields} +@generated function applychain(layers::NamedTuple{fields}, x, + ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple{fields}) where {fields} N = length(fields) x_symbols = [gensym() for _ in 1:N] st_symbols = [gensym() for _ in 1:N] calls = [:(($(x_symbols[1]), $(st_symbols[1])) = layers[1](x, ps.layer_1, st.layer_1))] +<<<<<<< HEAD append!( calls, [ @@ -523,6 +558,15 @@ flatten_model(x) = x ) push!(calls, :(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),))))) push!(calls, :(return $(x_symbols[N]), st)) +======= + append!(calls, + [:(($(x_symbols[i]), $(st_symbols[i])) = layers[$i]($(x_symbols[i - 1]), + ps.$(fields[i]), + st.$(fields[i]))) + for i in 2:N]) + append!(calls, [:(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),))))]) + append!(calls, [:(return $(x_symbols[N]), st)]) +>>>>>>> 862526f (enforce SciMLStyle) return Expr(:block, calls...) 
end @@ -559,7 +603,7 @@ Create a traditional fully connected layer, whose forward pass is given by: `y = * `weight`: Weight Matrix of size `out_dims × in_dims` * `bias`: Bias of size `out_dims × 1` (present if `bias=true`) """ -struct Dense{bias,F1,F2,F3} <: AbstractExplicitLayer +struct Dense{bias, F1, F2, F3} <: AbstractExplicitLayer activation::F1 in_dims::Int out_dims::Int @@ -574,53 +618,66 @@ function Base.show(io::IO, d::Dense{bias}) where {bias} return print(io, ")") end -function Dense(mapping::Pair{<:Int,<:Int}, activation=identity; init_weight=glorot_uniform, init_bias=zeros32, bias::Bool=true) - return Dense(first(mapping), last(mapping), activation; init_weight=init_weight, init_bias=init_bias, bias=bias) +function Dense(mapping::Pair{<:Int, <:Int}, activation = identity; + init_weight = glorot_uniform, init_bias = zeros32, bias::Bool = true) + return Dense(first(mapping), last(mapping), activation; init_weight = init_weight, + init_bias = init_bias, bias = bias) end -function Dense(in_dims::Int, out_dims::Int, activation=identity; init_weight=glorot_uniform, init_bias=zeros32, bias::Bool=true) +function Dense(in_dims::Int, out_dims::Int, activation = identity; + init_weight = glorot_uniform, init_bias = zeros32, bias::Bool = true) activation = NNlib.fast_act(activation) - return Dense{bias,typeof(activation),typeof(init_weight),typeof(init_bias)}(activation, in_dims, out_dims, init_weight, init_bias) + return Dense{bias, typeof(activation), typeof(init_weight), typeof(init_bias)}(activation, + in_dims, + out_dims, + init_weight, + init_bias) end function initialparameters(rng::AbstractRNG, d::Dense{bias}) where {bias} if bias - return (weight=d.init_weight(rng, d.out_dims, d.in_dims), bias=d.init_bias(rng, d.out_dims, 1)) + return (weight = d.init_weight(rng, d.out_dims, d.in_dims), + bias = d.init_bias(rng, d.out_dims, 1)) else - return (weight=d.init_weight(rng, d.out_dims, d.in_dims),) + return (weight = d.init_weight(rng, d.out_dims, d.in_dims),) end end -parameterlength(d::Dense{bias}) where {bias} = bias ? d.out_dims * (d.in_dims + 1) : d.out_dims * d.in_dims +function parameterlength(d::Dense{bias}) where {bias} + bias ? 
d.out_dims * (d.in_dims + 1) : d.out_dims * d.in_dims +end statelength(d::Dense) = 0 -@inline function (d::Dense{false})(x::AbstractArray, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) +@inline function (d::Dense{false})(x::AbstractArray, ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) return applyactivation(d.activation, ps.weight * x), st end -@inline function (d::Dense{false,typeof(identity)})( - x::AbstractArray, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple -) +@inline function (d::Dense{false, typeof(identity)})(x::AbstractArray, + ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) return ps.weight * x, st end -@inline function (d::Dense{true})(x::AbstractArray, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) +@inline function (d::Dense{true})(x::AbstractArray, ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) return applyactivation(d.activation, elementwise_add(ps.weight * x, ps.bias)), st end -@inline function (d::Dense{true,typeof(identity)})( - x::AbstractArray, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple -) +@inline function (d::Dense{true, typeof(identity)})(x::AbstractArray, + ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) return elementwise_add(ps.weight * x, ps.bias), st end -@inline function (d::Dense{true})(x::AbstractVector, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) +@inline function (d::Dense{true})(x::AbstractVector, ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) return applyactivation(d.activation, elementwise_add(ps.weight * x, vec(ps.bias))), st end -@inline function (d::Dense{true,typeof(identity)})( - x::AbstractVector, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple -) +@inline function (d::Dense{true, typeof(identity)})(x::AbstractVector, + ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) return elementwise_add(ps.weight * x, vec(ps.bias)), st end @@ -654,7 +711,7 @@ Create a Sparsely Connected Layer with a very specific structure (only Diagonal * `weight`: Weight Vector of size `(dims,)` * `bias`: Bias of size `(dims,)` """ -struct Scale{bias,F1,D,F2,F3} <: AbstractExplicitLayer +struct Scale{bias, F1, D, F2, F3} <: AbstractExplicitLayer activation::F1 dims::D init_weight::F2 @@ -667,31 +724,42 @@ function Base.show(io::IO, d::Scale) return print(io, ")") end -function Scale(dims, activation=identity; init_weight=glorot_uniform, init_bias=zeros32, bias::Bool=true) +function Scale(dims, activation = identity; init_weight = glorot_uniform, + init_bias = zeros32, bias::Bool = true) activation = NNlib.fast_act(activation) - return Scale{bias,typeof(activation),typeof(dims),typeof(init_weight),typeof(init_bias)}(activation, dims, init_weight, init_bias) + return Scale{bias, typeof(activation), typeof(dims), typeof(init_weight), + typeof(init_bias)}(activation, dims, init_weight, init_bias) end function initialparameters(rng::AbstractRNG, d::Scale{true}) - return (weight=d.init_weight(rng, d.dims), bias=d.init_bias(rng, d.dims)) + return (weight = d.init_weight(rng, d.dims), bias = d.init_bias(rng, d.dims)) +end +function initialparameters(rng::AbstractRNG, d::Scale{false}) + (weight = d.init_weight(rng, d.dims),) end -initialparameters(rng::AbstractRNG, d::Scale{false}) = (weight=d.init_weight(rng, d.dims),) parameterlength(d::Scale{bias}) where {bias} = (1 + bias) * d.dims statelength(d::Scale) = 0 -function (d::Scale{true})(x::AbstractArray, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) - return applyactivation(d.activation, 
elementwise_add(elementwise_mul(ps.weight, x), ps.bias)), st +function (d::Scale{true})(x::AbstractArray, ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) + return applyactivation(d.activation, + elementwise_add(elementwise_mul(ps.weight, x), ps.bias)), st end -function (d::Scale{true,typeof(identity)})(x::AbstractArray, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) +function (d::Scale{true, typeof(identity)})(x::AbstractArray, + ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) return elementwise_add(elementwise_mul(ps.weight, x), ps.bias), st end -function (d::Scale{false})(x::AbstractArray, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) +function (d::Scale{false})(x::AbstractArray, ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) return applyactivation(d.activation, elementwise_mul(ps.weight, x)), st end -function (d::Scale{false,typeof(identity)})(x::AbstractArray, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) +function (d::Scale{false, typeof(identity)})(x::AbstractArray, + ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) return elementwise_mul(ps.weight, x), st end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 4a82f65b6..0e50188cf 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -42,55 +42,67 @@ Image data should be stored in WHCN order (width, height, channels, batch). In o * `weight`: Convolution kernel * `bias`: Bias (present if `bias=true`) """ -struct Conv{N,bias,M,F1,F2} <: AbstractExplicitLayer +struct Conv{N, bias, M, F1, F2} <: AbstractExplicitLayer activation::F1 in_chs::Int out_chs::Int - kernel_size::NTuple{N,Int} - stride::NTuple{N,Int} - pad::NTuple{M,Int} - dilation::NTuple{N,Int} + kernel_size::NTuple{N, Int} + stride::NTuple{N, Int} + pad::NTuple{M, Int} + dilation::NTuple{N, Int} groups::Int init_weight::F2 end -function Conv( - k::NTuple{N,Integer}, - ch::Pair{<:Integer,<:Integer}, - activation=identity; - init_weight=glorot_uniform, - stride=1, - pad=0, - dilation=1, - groups=1, - bias=true, -) where {N} +function Conv(k::NTuple{N, Integer}, + ch::Pair{<:Integer, <:Integer}, + activation = identity; + init_weight = glorot_uniform, + stride = 1, + pad = 0, + dilation = 1, + groups = 1, + bias = true) where {N} stride = expand(Val(N), stride) dilation = expand(Val(N), dilation) pad = calc_padding(Conv, pad, k, dilation, stride) activation = NNlib.fast_act(activation) - return Conv{N,bias,length(pad),typeof(activation),typeof(init_weight)}( - activation, first(ch), last(ch), k, stride, pad, dilation, groups, init_weight - ) + return Conv{N, bias, length(pad), typeof(activation), typeof(init_weight)}(activation, + first(ch), + last(ch), k, + stride, pad, + dilation, + groups, + init_weight) end -function initialparameters(rng::AbstractRNG, c::Conv{N,bias}) where {N,bias} - weight = convfilter(rng, c.kernel_size, c.in_chs => c.out_chs; init=c.init_weight, groups=c.groups) - return bias ? (weight=weight, bias=zeros(eltype(weight), ntuple(_ -> 1, N)..., c.out_chs, 1)) : (weight=weight,) +function initialparameters(rng::AbstractRNG, c::Conv{N, bias}) where {N, bias} + weight = convfilter(rng, c.kernel_size, c.in_chs => c.out_chs; init = c.init_weight, + groups = c.groups) + return bias ? + (weight = weight, + bias = zeros(eltype(weight), ntuple(_ -> 1, N)..., c.out_chs, 1)) : + (weight = weight,) end -function parameterlength(c::Conv{N,bias}) where {N,bias} +function parameterlength(c::Conv{N, bias}) where {N, bias} return prod(c.kernel_size) * c.in_chs * c.out_chs ÷ c.groups + (bias ? 
c.out_chs : 0) end -@inline function (c::Conv{N,false})(x::AbstractArray, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) where {N} - cdims = DenseConvDims(x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groups=c.groups) +@inline function (c::Conv{N, false})(x::AbstractArray, + ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) where {N} + cdims = DenseConvDims(x, ps.weight; stride = c.stride, padding = c.pad, + dilation = c.dilation, groups = c.groups) return applyactivation(c.activation, conv_wrapper(x, ps.weight, cdims)), st end -@inline function (c::Conv{N,true})(x::AbstractArray, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) where {N} - cdims = DenseConvDims(x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groups=c.groups) - return applyactivation(c.activation, elementwise_add(conv_wrapper(x, ps.weight, cdims), ps.bias)), st +@inline function (c::Conv{N, true})(x::AbstractArray, ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) where {N} + cdims = DenseConvDims(x, ps.weight; stride = c.stride, padding = c.pad, + dilation = c.dilation, groups = c.groups) + return applyactivation(c.activation, + elementwise_add(conv_wrapper(x, ps.weight, cdims), ps.bias)), st end function Base.show(io::IO, l::Conv) @@ -100,7 +112,7 @@ function Base.show(io::IO, l::Conv) return print(io, ")") end -function _print_conv_opt(io::IO, l::Conv{N,bias}) where {N,bias} +function _print_conv_opt(io::IO, l::Conv{N, bias}) where {N, bias} l.activation == identity || print(io, ", ", l.activation) all(==(0), l.pad) || print(io, ", pad=", _maybetuple_string(l.pad)) all(==(1), l.stride) || print(io, ", stride=", _maybetuple_string(l.stride)) @@ -139,20 +151,20 @@ Max pooling layer, which replaces all pixels in a block of size `window` with th See also [`Conv`](@ref), [`MeanPool`](@ref), [`GlobalMaxPool`](@ref), [`AdaptiveMaxPool`](@ref) """ -struct MaxPool{N,M} <: AbstractExplicitLayer - k::NTuple{N,Int} - pad::NTuple{M,Int} - stride::NTuple{N,Int} +struct MaxPool{N, M} <: AbstractExplicitLayer + k::NTuple{N, Int} + pad::NTuple{M, Int} + stride::NTuple{N, Int} end -function MaxPool(k::NTuple{N,Integer}; pad=0, stride=k) where {N} +function MaxPool(k::NTuple{N, Integer}; pad = 0, stride = k) where {N} stride = expand(Val(N), stride) pad = calc_padding(MaxPool, pad, k, 1, stride) - return MaxPool{N,length(pad)}(k, pad, stride) + return MaxPool{N, length(pad)}(k, pad, stride) end -function (m::MaxPool{N,M})(x, ps, st::NamedTuple) where {N,M} - pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride) +function (m::MaxPool{N, M})(x, ps, st::NamedTuple) where {N, M} + pdims = PoolDims(x, m.k; padding = m.pad, stride = m.stride) return maxpool(x, pdims), st end @@ -192,20 +204,20 @@ Mean pooling layer, which replaces all pixels in a block of size `window` with t See also [`Conv`](@ref), [`MaxPool`](@ref), [`GlobalMeanPool`](@ref), [`AdaptiveMeanPool`](@ref) """ -struct MeanPool{N,M} <: AbstractExplicitLayer - k::NTuple{N,Int} - pad::NTuple{M,Int} - stride::NTuple{N,Int} +struct MeanPool{N, M} <: AbstractExplicitLayer + k::NTuple{N, Int} + pad::NTuple{M, Int} + stride::NTuple{N, Int} end -function MeanPool(k::NTuple{N,Integer}; pad=0, stride=k) where {N} +function MeanPool(k::NTuple{N, Integer}; pad = 0, stride = k) where {N} stride = expand(Val(N), stride) pad = calc_padding(MeanPool, pad, k, 1, stride) - return MeanPool{N,length(pad)}(k, pad, stride) + return MeanPool{N, length(pad)}(k, pad, stride) end -function (m::MeanPool{N,M})(x, ps, st::NamedTuple) where 
{N,M} - pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride) +function (m::MeanPool{N, M})(x, ps, st::NamedTuple) where {N, M} + pdims = PoolDims(x, m.k; padding = m.pad, stride = m.stride) return meanpool(x, pdims), st end @@ -255,43 +267,45 @@ Currently supported upsampling `mode`s and corresponding NNlib's methods are: * Upsampled Input of size `size` or of size `(I_1 × scale[1], ..., I_N × scale[N], C, N)` * Empty `NamedTuple()` """ -struct Upsample{mode,S,T} <: AbstractExplicitLayer +struct Upsample{mode, S, T} <: AbstractExplicitLayer scale::S size::T end -function Upsample(mode::Symbol=:nearest; scale=nothing, size=nothing) - mode in [:nearest, :bilinear, :trilinear] || throw(ArgumentError("mode=:$mode is not supported.")) +function Upsample(mode::Symbol = :nearest; scale = nothing, size = nothing) + mode in [:nearest, :bilinear, :trilinear] || + throw(ArgumentError("mode=:$mode is not supported.")) if !(isnothing(scale) ⊻ isnothing(size)) throw(ArgumentError("Either scale or size should be specified (but not both).")) end - return Upsample{mode,typeof(scale),typeof(size)}(scale, size) + return Upsample{mode, typeof(scale), typeof(size)}(scale, size) end -Upsample(scale, mode::Symbol=:nearest) = Upsample(mode; scale) +Upsample(scale, mode::Symbol = :nearest) = Upsample(mode; scale) function (m::Upsample{:nearest})(x::AbstractArray, ps, st::NamedTuple) return NNlib.upsample_nearest(x, m.scale), st end -function (m::Upsample{:nearest,Int})(x::AbstractArray{T,N}, ps, st::NamedTuple) where {T,N} +function (m::Upsample{:nearest, Int})(x::AbstractArray{T, N}, ps, + st::NamedTuple) where {T, N} return NNlib.upsample_nearest(x, ntuple(i -> m.scale, N - 2)), st end -function (m::Upsample{:nearest,Nothing})(x::AbstractArray, ps, st::NamedTuple) - return NNlib.upsample_nearest(x; size=m.size), st +function (m::Upsample{:nearest, Nothing})(x::AbstractArray, ps, st::NamedTuple) + return NNlib.upsample_nearest(x; size = m.size), st end function (m::Upsample{:bilinear})(x::AbstractArray, ps, st::NamedTuple) return NNlib.upsample_bilinear(x, m.scale), st end -function (m::Upsample{:bilinear,Nothing})(x::AbstractArray, ps, st::NamedTuple) - return NNlib.upsample_bilinear(x; size=m.size), st +function (m::Upsample{:bilinear, Nothing})(x::AbstractArray, ps, st::NamedTuple) + return NNlib.upsample_bilinear(x; size = m.size), st end function (m::Upsample{:trilinear})(x::AbstractArray, ps, st::NamedTuple) return NNlib.upsample_trilinear(x, m.scale), st end -function (m::Upsample{:trilinear,Nothing})(x::AbstractArray, ps, st::NamedTuple) - return NNlib.upsample_trilinear(x; size=m.size), st +function (m::Upsample{:trilinear, Nothing})(x::AbstractArray, ps, st::NamedTuple) + return NNlib.upsample_trilinear(x; size = m.size), st end function Base.show(io::IO, u::Upsample{mode}) where {mode} @@ -366,12 +380,12 @@ Adaptive Max Pooling layer. Calculates the necessary window size such that its o See also [`MaxPool`](@ref), [`AdaptiveMeanPool`](@ref). 
""" -struct AdaptiveMaxPool{S,O} <: AbstractExplicitLayer - out::NTuple{O,Int} - AdaptiveMaxPool(out::NTuple{O,Int}) where {O} = new{O + 2,O}(out) +struct AdaptiveMaxPool{S, O} <: AbstractExplicitLayer + out::NTuple{O, Int} + AdaptiveMaxPool(out::NTuple{O, Int}) where {O} = new{O + 2, O}(out) end -function (a::AdaptiveMaxPool{S})(x::AbstractArray{T,S}, ps, st::NamedTuple) where {S,T} +function (a::AdaptiveMaxPool{S})(x::AbstractArray{T, S}, ps, st::NamedTuple) where {S, T} pdims = compute_adaptive_pooling_dims(x, a.out) return maxpool(x, pdims), st end @@ -400,12 +414,12 @@ Adaptive Mean Pooling layer. Calculates the necessary window size such that its See also [`MeanPool`](@ref), [`AdaptiveMaxPool`](@ref). """ -struct AdaptiveMeanPool{S,O} <: AbstractExplicitLayer - out::NTuple{O,Int} - AdaptiveMeanPool(out::NTuple{O,Int}) where {O} = new{O + 2,O}(out) +struct AdaptiveMeanPool{S, O} <: AbstractExplicitLayer + out::NTuple{O, Int} + AdaptiveMeanPool(out::NTuple{O, Int}) where {O} = new{O + 2, O}(out) end -function (a::AdaptiveMeanPool{S})(x::AbstractArray{T,S}, ps, st::NamedTuple) where {S,T} +function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}, ps, st::NamedTuple) where {S, T} pdims = compute_adaptive_pooling_dims(x, a.out) return meanpool(x, pdims), st end diff --git a/src/layers/display.jl b/src/layers/display.jl index 9d12dfd11..9817407d7 100644 --- a/src/layers/display.jl +++ b/src/layers/display.jl @@ -8,7 +8,7 @@ function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitContainerLayer end end -function _big_show(io::IO, obj, indent::Int=0, name=nothing) +function _big_show(io::IO, obj, indent::Int = 0, name = nothing) pre, post = "(", ")" children = _get_children(obj) if obj isa Function @@ -21,7 +21,7 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing) for k in Base.keys(obj) _big_show(io, obj.layers[k], indent + 4, k) end - elseif obj isa Parallel{<:Any,<:NamedTuple} + elseif obj isa Parallel{<:Any, <:NamedTuple} if obj.connection !== nothing _big_show(io, obj.connection, indent + 4) end @@ -65,7 +65,9 @@ _show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LS function _get_children(l::AbstractExplicitContainerLayer{names}) where {names} return NamedTuple{names}(getfield.((l,), names)) end -_get_children(p::Parallel) = p.connection === nothing ? p.layers : (p.connection, p.layers...) +function _get_children(p::Parallel) + p.connection === nothing ? p.layers : (p.connection, p.layers...) +end _get_children(s::SkipConnection) = (s.layers, s.connection) _get_children(s::WeightNorm) = (s.layer,) _get_children(::Any) = () @@ -78,17 +80,19 @@ function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitLayer) end end -function _layer_show(io::IO, layer, indent::Int=0, name=nothing) +function _layer_show(io::IO, layer, indent::Int = 0, name = nothing) _str = isnothing(name) ? "" : "$name = " - str = _str * sprint(show, layer; context=io) + str = _str * sprint(show, layer; context = io) print(io, " "^indent, str, indent == 0 ? "" : ",") paramlength = parameterlength(layer) if paramlength > 0 print(io, " "^max(2, (indent == 0 ? 20 : 39) - indent - length(str))) - printstyled(io, "# ", underscorise(paramlength), " parameters"; color=:light_black) + printstyled(io, "# ", underscorise(paramlength), " parameters"; + color = :light_black) nonparam = statelength(layer) if nonparam > 0 - printstyled(io, ", plus ", underscorise(nonparam), indent == 0 ? 
" non-trainable" : ""; color=:light_black) + printstyled(io, ", plus ", underscorise(nonparam), + indent == 0 ? " non-trainable" : ""; color = :light_black) end end return indent == 0 || println(io) @@ -100,26 +104,28 @@ function _big_finale(io::IO, m) pars = underscorise(paramlength) bytes = Base.format_bytes(Base.summarysize(m)) nonparam = underscorise(nonparamlength) - printstyled(io, " "^08, "# Total: "; color=:light_black) + printstyled(io, " "^08, "# Total: "; color = :light_black) println(io, pars, " parameters,") - printstyled(io, " "^10, "# plus "; color=:light_black) + printstyled(io, " "^10, "# plus "; color = :light_black) print(io, nonparam, " states, ") - printstyled(io, "summarysize "; color=:light_black) + printstyled(io, "summarysize "; color = :light_black) print(io, bytes, ".") return end # utility functions -underscorise(n::Integer) = join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') +function underscorise(n::Integer) + join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') +end function _nan_show(io::IO, x) if !isempty(x) && _all(iszero, x) - printstyled(io, " (all zero)"; color=:cyan) + printstyled(io, " (all zero)"; color = :cyan) elseif _any(isnan, x) - printstyled(io, " (some NaN)"; color=:red) + printstyled(io, " (some NaN)"; color = :red) elseif _any(isinf, x) - printstyled(io, " (some Inf)"; color=:red) + printstyled(io, " (some Inf)"; color = :red) end end diff --git a/src/layers/dropout.jl b/src/layers/dropout.jl index 2679d7e38..6b915d196 100644 --- a/src/layers/dropout.jl +++ b/src/layers/dropout.jl @@ -29,7 +29,7 @@ Call [`Lux.testmode`](@ref) to switch to test mode. See also [`VariationalHiddenDropout`](@ref) """ -struct Dropout{T,D} <: AbstractExplicitLayer +struct Dropout{T, D} <: AbstractExplicitLayer p::T dims::D end @@ -37,17 +37,17 @@ end function initialstates(rng::AbstractRNG, ::Dropout) # FIXME: Take PRNGs seriously randn(rng, 1) - return (rng=replicate(rng), training=Val(true)) + return (rng = replicate(rng), training = Val(true)) end -function Dropout(p; dims=:) +function Dropout(p; dims = :) @assert 0 ≤ p ≤ 1 return Dropout(p, dims) end function (d::Dropout{T})(x::AbstractArray{T}, ps, st::NamedTuple) where {T} y, _, rng = dropout(st.rng, x, d.p, d.dims, st.training) - return y, merge(st, (rng=rng,)) + return y, merge(st, (rng = rng,)) end function Base.show(io::IO, d::Dropout) @@ -89,7 +89,7 @@ Call [`Lux.testmode`](@ref) to switch to test mode. 
See also [`Dropout`](@ref) """ -struct VariationalHiddenDropout{T,D} <: AbstractExplicitLayer +struct VariationalHiddenDropout{T, D} <: AbstractExplicitLayer p::T dims::D end @@ -97,17 +97,19 @@ end function initialstates(rng::AbstractRNG, ::VariationalHiddenDropout) # FIXME: Take PRNGs seriously randn(rng, 1) - return (rng=replicate(rng), training=Val(true), update_mask=Val(true), mask=nothing) + return (rng = replicate(rng), training = Val(true), update_mask = Val(true), + mask = nothing) end -function VariationalHiddenDropout(p; dims=:) +function VariationalHiddenDropout(p; dims = :) @assert 0 ≤ p ≤ 1 return VariationalHiddenDropout(p, dims) end function (d::VariationalHiddenDropout{T})(x::AbstractArray{T}, ps, st::NamedTuple) where {T} - y, mask, rng, update_mask = dropout(st.rng, x, st.mask, d.p, d.dims, st.training, st.update_mask) - return y, merge(st, (mask=mask, rng=rng, update_mask=update_mask)) + y, mask, rng, update_mask = dropout(st.rng, x, st.mask, d.p, d.dims, st.training, + st.update_mask) + return y, merge(st, (mask = mask, rng = rng, update_mask = update_mask)) end function Base.show(io::IO, d::VariationalHiddenDropout) diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index 806089c1f..58e5c5b09 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -1,9 +1,4 @@ -abstract type AbstractNormalizationLayer{affine,track_stats} <: AbstractExplicitLayer end - -get_reduce_dims(::AbstractNormalizationLayer, ::AbstractArray) = error("Not Implemented Yet!!") - -get_proper_shape(::AbstractNormalizationLayer, ::AbstractArray, y::Nothing, args...) = y -get_proper_shape(::AbstractNormalizationLayer, x::AbstractArray{T,N}, y::AbstractArray{T,N}, args...) where {T,N} = y +abstract type AbstractNormalizationLayer{affine, track_stats} <: AbstractExplicitLayer end """ BatchNorm(chs::Integer, activation=identity; init_bias=zeros32, init_scale=ones32, affine=true, track_stats=true, epsilon=1f-5, momentum=0.1f0) @@ -67,7 +62,8 @@ m = Chain( See also [`GroupNorm`](@ref) """ -struct BatchNorm{affine,track_stats,F1,F2,F3,N} <: AbstractNormalizationLayer{affine,track_stats} +struct BatchNorm{affine, track_stats, F1, F2, F3, N} <: + AbstractNormalizationLayer{affine, track_stats} activation::F1 epsilon::N momentum::N @@ -76,65 +72,66 @@ struct BatchNorm{affine,track_stats,F1,F2,F3,N} <: AbstractNormalizationLayer{af init_scale::F3 end -function BatchNorm( - chs::Int, - activation=identity; - init_bias=zeros32, - init_scale=ones32, - affine::Bool=true, - track_stats::Bool=true, - epsilon=1.0f-5, - momentum=0.1f0, -) +function BatchNorm(chs::Int, + activation = identity; + init_bias = zeros32, + init_scale = ones32, + affine::Bool = true, + track_stats::Bool = true, + epsilon = 1.0f-5, + momentum = 0.1f0) activation = NNlib.fast_act(activation) - return BatchNorm{affine,track_stats,typeof(activation),typeof(init_bias),typeof(init_scale),typeof(epsilon)}( - activation, epsilon, momentum, chs, init_bias, init_scale - ) + return BatchNorm{affine, track_stats, typeof(activation), typeof(init_bias), + typeof(init_scale), typeof(epsilon)}(activation, epsilon, momentum, + chs, init_bias, init_scale) end function initialparameters(rng::AbstractRNG, l::BatchNorm{affine}) where {affine} - return affine ? (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs)) : NamedTuple() + return affine ? 
(scale = l.init_scale(rng, l.chs), bias = l.init_bias(rng, l.chs)) : + NamedTuple() end -function initialstates(rng::AbstractRNG, l::BatchNorm{affine,track_stats}) where {affine,track_stats} +function initialstates(rng::AbstractRNG, + l::BatchNorm{affine, track_stats}) where {affine, track_stats} return if track_stats - (running_mean=zeros32(rng, l.chs), running_var=ones32(rng, l.chs), training=Val(true)) + (running_mean = zeros32(rng, l.chs), running_var = ones32(rng, l.chs), + training = Val(true)) else - (running_mean=nothing, running_var=nothing, training=Val(true)) + (running_mean = nothing, running_var = nothing, training = Val(true)) end end parameterlength(l::BatchNorm{affine}) where {affine} = affine ? (l.chs * 2) : 0 -statelength(l::BatchNorm{affine,track_stats}) where {affine,track_stats} = (track_stats ? 2 * l.chs : 0) + 1 - -function get_proper_shape(::BatchNorm, x::AbstractArray{T,N}, y::AbstractVector) where {T,N} - return reshape(y, ntuple(i -> i == N - 1 ? length(y) : 1, N)...) +function statelength(l::BatchNorm{affine, track_stats}) where {affine, track_stats} + (track_stats ? 2 * l.chs : 0) + 1 end -function (BN::BatchNorm)(x::AbstractArray{T,N}, ps, st::NamedTuple) where {T,N} +function (BN::BatchNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N} @assert size(x, N - 1) == BN.chs - @assert !istraining(st) || size(x, N) > 1 "During `training`, `BatchNorm` can't handle Batch Size == 1" - - x_normalized, xmean, xvar = normalization( - x, - st.running_mean, - st.running_var, - ps.scale, - ps.bias, - BN.activation, - collect([1:(N - 2); N]), - st.training, - BN.momentum, - BN.epsilon, - ) - - st = merge(st, (running_mean=xmean, running_var=xvar)) + @assert !istraining(st)||size(x, N) > 1 "During `training`, `BatchNorm` can't handle Batch Size == 1" + + x_normalized, xmean, xvar = normalization(x, + st.running_mean, + st.running_var, + ps.scale, + ps.bias, + BN.activation, + collect([1:(N - 2); N]), + st.training, + BN.momentum, + BN.epsilon) + + st = merge(st, (running_mean = xmean, running_var = xvar)) return x_normalized, st end -function (BN::BatchNorm{affine,track_stats})( - x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, ps, st::NamedTuple -) where {T<:Union{Float32,Float64},affine,track_stats} +function (BN::BatchNorm{affine, track_stats})(x::Union{CuArray{T, 2}, CuArray{T, 4}, + CuArray{T, 5}}, ps, + st::NamedTuple) where { + T <: + Union{Float32, Float64 + }, affine, + track_stats} # NNlibCUDA silently updates running_mean and running_var so copying them if istraining(st) running_mean2 = track_stats ? copy(st.running_mean) : nothing @@ -146,30 +143,27 @@ function (BN::BatchNorm{affine,track_stats})( else N = ndims(x) reduce_dims = collect([1:(N - 2); N]) - running_mean2 = mean(x; dims=reduce_dims) - running_var2 = var(x; mean=running_mean2, dims=reduce_dims, corrected=false) + running_mean2 = mean(x; dims = reduce_dims) + running_var2 = var(x; mean = running_mean2, dims = reduce_dims, + corrected = false) end end - res = applyactivation( - BN.activation, - batchnorm( - affine ? ps.scale : nothing, - affine ? ps.bias : nothing, - x, - running_mean2, - running_var2, - BN.momentum; - eps=BN.epsilon, - training=istraining(st), - ), - ) + res = applyactivation(BN.activation, + batchnorm(affine ? ps.scale : nothing, + affine ? 
ps.bias : nothing, + x, + running_mean2, + running_var2, + BN.momentum; + eps = BN.epsilon, + training = istraining(st))) if track_stats - st = merge(st, (running_mean=running_mean2, running_var=running_var2)) + st = merge(st, (running_mean = running_mean2, running_var = running_var2)) end return res, st end -function Base.show(io::IO, l::BatchNorm{affine,track_stats}) where {affine,track_stats} +function Base.show(io::IO, l::BatchNorm{affine, track_stats}) where {affine, track_stats} print(io, "BatchNorm($(l.chs)") (l.activation == identity) || print(io, ", $(l.activation)") affine || print(io, ", affine=false") @@ -241,7 +235,8 @@ m = Chain( See also [`BatchNorm`](@ref) """ -struct GroupNorm{affine,track_stats,F1,F2,F3,N} <: AbstractNormalizationLayer{affine,track_stats} +struct GroupNorm{affine, track_stats, F1, F2, F3, N} <: + AbstractNormalizationLayer{affine, track_stats} activation::F1 epsilon::N momentum::N @@ -251,64 +246,66 @@ struct GroupNorm{affine,track_stats,F1,F2,F3,N} <: AbstractNormalizationLayer{af groups::Int end -function GroupNorm( - chs::Int, - groups::Int, - activation=identity; - init_bias=zeros32, - init_scale=ones32, - affine::Bool=true, - track_stats::Bool=true, - epsilon=1.0f-5, - momentum=0.1f0, -) - @assert chs % groups == 0 "The number of groups ($(groups)) must divide the number of channels ($chs)" +function GroupNorm(chs::Int, + groups::Int, + activation = identity; + init_bias = zeros32, + init_scale = ones32, + affine::Bool = true, + track_stats::Bool = true, + epsilon = 1.0f-5, + momentum = 0.1f0) + @assert chs % groups==0 "The number of groups ($(groups)) must divide the number of channels ($chs)" activation = NNlib.fast_act(activation) - return GroupNorm{affine,track_stats,typeof(activation),typeof(init_bias),typeof(init_scale),typeof(epsilon)}( - activation, epsilon, momentum, chs, init_bias, init_scale, groups - ) + return GroupNorm{affine, track_stats, typeof(activation), typeof(init_bias), + typeof(init_scale), typeof(epsilon)}(activation, epsilon, momentum, + chs, init_bias, init_scale, + groups) end function initialparameters(rng::AbstractRNG, l::GroupNorm{affine}) where {affine} - return affine ? (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs)) : NamedTuple() + return affine ? (scale = l.init_scale(rng, l.chs), bias = l.init_bias(rng, l.chs)) : + NamedTuple() end -function initialstates(rng::AbstractRNG, l::GroupNorm{affine,track_stats}) where {affine,track_stats} +function initialstates(rng::AbstractRNG, + l::GroupNorm{affine, track_stats}) where {affine, track_stats} return if track_stats - (running_mean=zeros32(rng, l.groups), running_var=ones32(rng, l.groups), training=Val(true)) + (running_mean = zeros32(rng, l.groups), running_var = ones32(rng, l.groups), + training = Val(true)) else - (running_mean=nothing, running_var=nothing, training=Val(true)) + (running_mean = nothing, running_var = nothing, training = Val(true)) end end parameterlength(l::GroupNorm{affine}) where {affine} = affine ? (l.chs * 2) : 0 -statelength(l::GroupNorm{affine,track_stats}) where {affine,track_stats} = (track_stats ? 2 * l.groups : 0) + 1 +function statelength(l::GroupNorm{affine, track_stats}) where {affine, track_stats} + (track_stats ? 
2 * l.groups : 0) + 1 +end -function (GN::GroupNorm)(x::AbstractArray{T,N}, ps, st::NamedTuple) where {T,N} +function (GN::GroupNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N} sz = size(x) @assert N > 2 @assert sz[N - 1] == GN.chs x_ = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ GN.groups, GN.groups, sz[N]) - x_normalized, xmean, xvar = normalization( - x_, - st.running_mean, - st.running_var, - ps.scale, - ps.bias, - GN.activation, - collect(1:(N - 1)), - st.training, - GN.momentum, - GN.epsilon, - ) - - st = merge(st, (running_mean=xmean, running_var=xvar)) + x_normalized, xmean, xvar = normalization(x_, + st.running_mean, + st.running_var, + ps.scale, + ps.bias, + GN.activation, + collect(1:(N - 1)), + st.training, + GN.momentum, + GN.epsilon) + + st = merge(st, (running_mean = xmean, running_var = xvar)) return reshape(x_normalized, sz), st end -function Base.show(io::IO, l::GroupNorm{affine,track_stats}) where {affine,track_stats} +function Base.show(io::IO, l::GroupNorm{affine, track_stats}) where {affine, track_stats} print(io, "GroupNorm($(l.chs), $(l.groups)") (l.activation == identity) || print(io, ", $(l.activation)") affine || print(io, ", affine=false") @@ -349,18 +346,18 @@ Weight normalization is a reparameterization that decouples the magnitude of a w * Same as that of `layer` """ -struct WeightNorm{which_params,L<:AbstractExplicitLayer,D} <: AbstractExplicitLayer +struct WeightNorm{which_params, L <: AbstractExplicitLayer, D} <: AbstractExplicitLayer layer::L dims::D end -function WeightNorm( - layer::AbstractExplicitLayer, which_params::NTuple{N,Symbol}, dims::Union{Tuple,Nothing}=nothing -) where {N} - return WeightNorm{Val{which_params},typeof(layer),typeof(dims)}(layer, dims) +function WeightNorm(layer::AbstractExplicitLayer, which_params::NTuple{N, Symbol}, + dims::Union{Tuple, Nothing} = nothing) where {N} + return WeightNorm{Val{which_params}, typeof(layer), typeof(dims)}(layer, dims) end -function initialparameters(rng::AbstractRNG, wn::WeightNorm{Val{which_params}}) where {which_params} +function initialparameters(rng::AbstractRNG, + wn::WeightNorm{Val{which_params}}) where {which_params} ps_layer = initialparameters(rng, wn.layer) ps_normalized = [] ps_unnormalized = [] @@ -377,19 +374,20 @@ function initialparameters(rng::AbstractRNG, wn::WeightNorm{Val{which_params}}) end end ps_unnormalized = length(ps_unnormalized) == 0 ? NamedTuple() : (; ps_unnormalized...) 
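# ---- Editor's aside (illustrative sketch, not part of the patch): usage of the
# ---- `WeightNorm` wrapper whose `initialparameters` method is being reformatted in
# ---- this hunk. Assuming the usual `Lux.setup`, wrapping a `Dense` layer splits its
# ---- parameters into `ps.normalized` (here `weight_g` and `weight_v`) and
# ---- `ps.unnormalized` (here `bias`); the forward pass recombines them as
# ---- `g .* v ./ norm(v)` before calling the wrapped layer.
using Lux, Random
rng = Random.default_rng()
wn = WeightNorm(Dense(3 => 5), (:weight,))
ps, st = Lux.setup(rng, wn)
y, _ = wn(randn(rng, Float32, 3, 4), ps, st)   # output has size (out_dims, batch)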
- return (normalized=(; ps_normalized...), unnormalized=ps_unnormalized) + return (normalized = (; ps_normalized...), unnormalized = ps_unnormalized) end initialstates(rng::AbstractRNG, wn::WeightNorm) = initialstates(rng, wn.layer) -function (wn::WeightNorm)(x, ps::Union{ComponentArray,NamedTuple}, s::NamedTuple) +function (wn::WeightNorm)(x, ps::Union{ComponentArray, NamedTuple}, s::NamedTuple) _ps = get_normalized_parameters(wn, wn.dims, ps.normalized) return wn.layer(x, merge(_ps, ps.unnormalized), s) end -@inbounds @generated function get_normalized_parameters( - ::WeightNorm{Val{which_params}}, dims::T, ps::Union{ComponentArray,NamedTuple} -) where {T,which_params} +@inbounds @generated function get_normalized_parameters(::WeightNorm{Val{which_params}}, + dims::T, + ps::Union{ComponentArray, NamedTuple + }) where {T, which_params} parameter_names = string.(which_params) v_parameter_names = Symbol.(parameter_names .* "_v") g_parameter_names = Symbol.(parameter_names .* "_g") @@ -405,15 +403,13 @@ end calls = [] for i in 1:length(parameter_names) - push!( - calls, - :( - $(normalized_params_symbol[i]) = - ps.$(v_parameter_names[i]) .* (ps.$(g_parameter_names[i]) ./ $(get_norm_except_invoke(i))) - ), - ) + push!(calls, + :($(normalized_params_symbol[i]) = ps.$(v_parameter_names[i]) .* + (ps.$(g_parameter_names[i]) ./ + $(get_norm_except_invoke(i))))) end - push!(calls, :(return NamedTuple{$(which_params)}(tuple($(Tuple(normalized_params_symbol)...))))) + push!(calls, + :(return NamedTuple{$(which_params)}(tuple($(Tuple(normalized_params_symbol)...))))) return Expr(:block, calls...) end diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index ce624d760..4c492ec6f 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -35,7 +35,7 @@ An Elman RNNCell cell with `activation` (typically set to `tanh` or `relu`). 
* `rng`: Controls the randomness (if any) in the initial state generation """ -struct RNNCell{bias,A,B,W,S} <: AbstractExplicitLayer +struct RNNCell{bias, A, B, W, S} <: AbstractExplicitLayer activation::A in_dims::Int out_dims::Int @@ -44,26 +44,22 @@ struct RNNCell{bias,A,B,W,S} <: AbstractExplicitLayer init_state::S end -function RNNCell( - (in_dims, out_dims)::Pair{<:Int,<:Int}, - activation=tanh; - bias::Bool=true, - init_bias=zeros32, - init_weight=glorot_uniform, - init_state=ones32, -) - return RNNCell{bias,typeof(activation),typeof(init_bias),typeof(init_weight),typeof(init_state)}( - activation, in_dims, out_dims, init_bias, init_weight, init_state - ) +function RNNCell((in_dims, out_dims)::Pair{<:Int, <:Int}, + activation = tanh; + bias::Bool = true, + init_bias = zeros32, + init_weight = glorot_uniform, + init_state = ones32) + return RNNCell{bias, typeof(activation), typeof(init_bias), typeof(init_weight), + typeof(init_state)}(activation, in_dims, out_dims, init_bias, + init_weight, init_state) end function initialparameters(rng::AbstractRNG, rnn::RNNCell{bias}) where {bias} - ps = ( - weight_ih=rnn.init_weight(rng, rnn.out_dims, rnn.in_dims), - weight_hh=rnn.init_weight(rng, rnn.out_dims, rnn.out_dims), - ) + ps = (weight_ih = rnn.init_weight(rng, rnn.out_dims, rnn.in_dims), + weight_hh = rnn.init_weight(rng, rnn.out_dims, rnn.out_dims)) if bias - ps = merge(ps, (bias=rnn.init_bias(rng, rnn.out_dims),)) + ps = merge(ps, (bias = rnn.init_bias(rng, rnn.out_dims),)) end return ps end @@ -71,40 +67,43 @@ end function initialstates(rng::AbstractRNG, ::RNNCell) # FIXME: Take PRNGs seriously randn(rng, 1) - return (rng=replicate(rng),) + return (rng = replicate(rng),) end -function (rnn::RNNCell)(x::AbstractMatrix, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) +function (rnn::RNNCell)(x::AbstractMatrix, ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) rng = replicate(st.rng) @set! 
st.rng = rng hidden_state = rnn.init_state(rng, rnn.out_dims, size(x, 2)) return rnn((x, hidden_state), ps, st) end -function (rnn::RNNCell{true})( - (x, hidden_state)::Tuple{<:AbstractMatrix,<:AbstractMatrix}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple -) +function (rnn::RNNCell{true})((x, hidden_state)::Tuple{<:AbstractMatrix, <:AbstractMatrix}, + ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) h_new = rnn.activation.(ps.weight_ih * x .+ ps.weight_hh * hidden_state .+ ps.bias) return h_new, st end -function (rnn::RNNCell{true,typeof(identity)})( - (x, hidden_state)::Tuple{<:AbstractMatrix,<:AbstractMatrix}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple -) +function (rnn::RNNCell{true, typeof(identity)})((x, + hidden_state)::Tuple{<:AbstractMatrix, + <:AbstractMatrix}, + ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) h_new = ps.weight_ih * x .+ ps.weight_hh * hidden_state .+ ps.bias return h_new, st end -function (rnn::RNNCell{false})( - (x, hidden_state)::Tuple{<:AbstractMatrix,<:AbstractMatrix}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple -) +function (rnn::RNNCell{false})((x, hidden_state)::Tuple{<:AbstractMatrix, <:AbstractMatrix}, + ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) h_new = rnn.activation.(ps.weight_ih * x .+ ps.weight_hh * hidden_state) return h_new, st end -function (rnn::RNNCell{false,typeof(identity)})( - (x, hidden_state)::Tuple{<:AbstractMatrix,<:AbstractMatrix}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple -) +function (rnn::RNNCell{false, typeof(identity)})((x, + hidden_state)::Tuple{<:AbstractMatrix, + <:AbstractMatrix}, + ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) h_new = ps.weight_ih * x .+ ps.weight_hh * hidden_state return h_new, st end @@ -163,7 +162,7 @@ Long Short-Term (LSTM) Cell * `rng`: Controls the randomness (if any) in the initial state generation """ -struct LSTMCell{B,W,S} <: AbstractExplicitLayer +struct LSTMCell{B, W, S} <: AbstractExplicitLayer in_dims::Int out_dims::Int init_bias::B @@ -171,31 +170,36 @@ struct LSTMCell{B,W,S} <: AbstractExplicitLayer init_state::S end -function LSTMCell( - (in_dims, out_dims)::Pair{<:Int,<:Int}; - init_weight::Tuple{Function,Function,Function,Function}=( - glorot_uniform, glorot_uniform, glorot_uniform, glorot_uniform - ), - init_bias::Tuple{Function,Function,Function,Function}=(zeros32, zeros32, ones32, zeros32), - init_state::Function=zeros32, -) +function LSTMCell((in_dims, out_dims)::Pair{<:Int, <:Int}; + init_weight::Tuple{Function, Function, Function, Function} = (glorot_uniform, + glorot_uniform, + glorot_uniform, + glorot_uniform), + init_bias::Tuple{Function, Function, Function, Function} = (zeros32, + zeros32, + ones32, + zeros32), + init_state::Function = zeros32) return LSTMCell(in_dims, out_dims, init_bias, init_weight, init_state) end function initialparameters(rng::AbstractRNG, lstm::LSTMCell) - weight_i = vcat([init_weight(rng, lstm.out_dims, lstm.in_dims) for init_weight in lstm.init_weight]...) - weight_h = vcat([init_weight(rng, lstm.out_dims, lstm.out_dims) for init_weight in lstm.init_weight]...) + weight_i = vcat([init_weight(rng, lstm.out_dims, lstm.in_dims) + for init_weight in lstm.init_weight]...) + weight_h = vcat([init_weight(rng, lstm.out_dims, lstm.out_dims) + for init_weight in lstm.init_weight]...) bias = vcat([init_bias(rng, lstm.out_dims, 1) for init_bias in lstm.init_bias]...) 
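# ---- Editor's aside (hypothetical smoke test, not part of the patch): the
# ---- reformatted `LSTMCell` constructor and the `initialparameters` method in this
# ---- hunk. The first call on a matrix seeds the hidden state and memory internally
# ---- (see the methods that follow); like every Lux layer it returns `(output, state)`.
using Lux, Random
rng = Random.default_rng()
cell = LSTMCell(3 => 5)
ps, st = Lux.setup(rng, cell)      # ps has :weight_i, :weight_h, :bias
x = randn(rng, Float32, 3, 8)      # (in_dims, batch)
out, st = cell(x, ps, st)          # one recurrent step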
- return (weight_i=weight_i, weight_h=weight_h, bias=bias) + return (weight_i = weight_i, weight_h = weight_h, bias = bias) end function initialstates(rng::AbstractRNG, ::LSTMCell) # FIXME: Take PRNGs seriously randn(rng, 1) - return (rng=replicate(rng),) + return (rng = replicate(rng),) end -function (lstm::LSTMCell)(x::AbstractMatrix, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) +function (lstm::LSTMCell)(x::AbstractMatrix, ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) rng = replicate(st.rng) @set! st.rng = rng hidden_state = lstm.init_state(rng, lstm.out_dims, size(x, 2)) @@ -203,11 +207,11 @@ function (lstm::LSTMCell)(x::AbstractMatrix, ps::Union{ComponentArray,NamedTuple return lstm((x, hidden_state, memory), ps, st) end -function (lstm::LSTMCell)( - (x, hidden_state, memory)::Tuple{<:AbstractMatrix,<:AbstractMatrix,<:AbstractMatrix}, - ps::Union{ComponentArray,NamedTuple}, - st::NamedTuple, -) +function (lstm::LSTMCell)((x, hidden_state, + memory)::Tuple{<:AbstractMatrix, <:AbstractMatrix, + <:AbstractMatrix}, + ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) g = ps.weight_i * x .+ ps.weight_h * hidden_state .+ ps.bias input, forget, cell, output = multigate(g, Val(4)) memory_new = @. sigmoid_fast(forget) * memory + sigmoid_fast(input) * tanh_fast(cell) @@ -259,7 +263,7 @@ Gated Recurrent Unit (GRU) Cell * `rng`: Controls the randomness (if any) in the initial state generation """ -struct GRUCell{W,B,S} <: AbstractExplicitLayer +struct GRUCell{W, B, S} <: AbstractExplicitLayer in_dims::Int out_dims::Int init_weight::W @@ -267,39 +271,42 @@ struct GRUCell{W,B,S} <: AbstractExplicitLayer init_state::S end -function GRUCell( - (in_dims, out_dims)::Pair{<:Int,<:Int}; - init_weight::Tuple{Function,Function,Function}=(glorot_uniform, glorot_uniform, glorot_uniform), - init_bias::Tuple{Function,Function,Function}=(zeros32, zeros32, zeros32), - init_state::Function=zeros32, -) +function GRUCell((in_dims, out_dims)::Pair{<:Int, <:Int}; + init_weight::Tuple{Function, Function, Function} = (glorot_uniform, + glorot_uniform, + glorot_uniform), + init_bias::Tuple{Function, Function, Function} = (zeros32, zeros32, + zeros32), + init_state::Function = zeros32) return GRUCell(in_dims, out_dims, init_weight, init_bias, init_state) end function initialparameters(rng::AbstractRNG, gru::GRUCell) - weight_i = vcat([init_weight(rng, gru.out_dims, gru.in_dims) for init_weight in gru.init_weight]...) - weight_h = vcat([init_weight(rng, gru.out_dims, gru.out_dims) for init_weight in gru.init_weight]...) + weight_i = vcat([init_weight(rng, gru.out_dims, gru.in_dims) + for init_weight in gru.init_weight]...) + weight_h = vcat([init_weight(rng, gru.out_dims, gru.out_dims) + for init_weight in gru.init_weight]...) bias_i = gru.init_bias[1](rng, gru.out_dims, 1) bias_h = vcat([init_bias(rng, gru.out_dims, 1) for init_bias in gru.init_bias]...) - return (weight_i=weight_i, weight_h=weight_h, bias_i=bias_i, bias_h=bias_h) + return (weight_i = weight_i, weight_h = weight_h, bias_i = bias_i, bias_h = bias_h) end function initialstates(rng::AbstractRNG, ::GRUCell) # FIXME: Take PRNGs seriously randn(rng, 1) - return (rng=replicate(rng),) + return (rng = replicate(rng),) end -function (gru::GRUCell)(x::AbstractMatrix, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) +function (gru::GRUCell)(x::AbstractMatrix, ps::Union{ComponentArray, NamedTuple}, + st::NamedTuple) rng = replicate(st.rng) @set! 
st.rng = rng hidden_state = gru.init_state(rng, gru.out_dims, size(x, 2)) return gru((x, hidden_state), ps, st) end -function (gru::GRUCell)( - (x, hidden_state)::Tuple{<:AbstractMatrix,<:AbstractMatrix}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple -) +function (gru::GRUCell)((x, hidden_state)::Tuple{<:AbstractMatrix, <:AbstractMatrix}, + ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) gxs = multigate(ps.weight_i * x, Val(3)) ghbs = multigate(ps.weight_h * hidden_state .+ ps.bias_h, Val(3)) diff --git a/src/nnlib.jl b/src/nnlib.jl index b64eb5af3..f6e627cd5 100644 --- a/src/nnlib.jl +++ b/src/nnlib.jl @@ -1,20 +1,18 @@ ## TODO: Eventually we want to move all these functions and their adjoints to NNlib.jl # Normalization Implementation -@inline function update_statistics( - x::AbstractArray{T,N}, - running_mean::AbstractArray{T,N}, - running_var::AbstractArray{T,N}, - batchmean::AbstractArray{T,N}, - batchvar::AbstractArray{T,N}, - momentum::T, - reduce_dims, -) where {T,N} +@inline function update_statistics(x::AbstractArray{T, N}, + running_mean::AbstractArray{T, N}, + running_var::AbstractArray{T, N}, + batchmean::AbstractArray{T, N}, + batchvar::AbstractArray{T, N}, + momentum::T, + reduce_dims) where {T, N} sx = size(x) m = T(prod((sx[i] for i in reduce_dims))) if reduce_dims[end] != N - batchmean = mean(batchmean; dims=N) - batchvar = mean(batchvar; dims=N) + batchmean = mean(batchmean; dims = N) + batchvar = mean(batchvar; dims = N) end running_mean = @. (1 - momentum) * running_mean + momentum * batchmean running_var = @. (1 - momentum) * running_var + momentum * batchvar * (m / (m - one(m))) @@ -29,32 +27,32 @@ Performs BatchNorm/GroupNorm/InstanceNorm based on input configuration !!! note Detailed docs are WIP """ -@inline function normalization( - x::AbstractArray{T,N}, - running_mean::Union{Nothing,AbstractVector{T}}, - running_var::Union{Nothing,AbstractVector{T}}, - scale::Union{Nothing,AbstractVector{T}}, - bias::Union{Nothing,AbstractVector{T}}, - activation, - reduce_dims, - t::Val, - momentum::T=T(0.1), - epsilon::T=T(1e-5); - kwargs..., -) where {T,N} - x_norm, running_mean_, running_var_ = normalization_forward( - x, - reshape_into_proper_shape(running_mean, x), - reshape_into_proper_shape(running_var, x), - reshape_into_proper_shape(scale, x), - reshape_into_proper_shape(bias, x), - activation, - reduce_dims, - t, - momentum, - epsilon; - kwargs..., - ) +@inline function normalization(x::AbstractArray{T, N}, + running_mean::Union{Nothing, AbstractVector{T}}, + running_var::Union{Nothing, AbstractVector{T}}, + scale::Union{Nothing, AbstractVector{T}}, + bias::Union{Nothing, AbstractVector{T}}, + activation, + reduce_dims, + t::Val, + momentum::T = T(0.1), + epsilon::T = T(1e-5); + kwargs...) where {T, N} + x_norm, running_mean_, running_var_ = normalization_forward(x, + reshape_into_proper_shape(running_mean, + x), + reshape_into_proper_shape(running_var, + x), + reshape_into_proper_shape(scale, + x), + reshape_into_proper_shape(bias, + x), + activation, + reduce_dims, + t, + momentum, + epsilon; + kwargs...) 
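# ---- Editor's aside (illustrative sketch, not part of the patch): a direct call of
# ---- the internal `normalization` helper whose signature is being rewrapped above,
# ---- mimicking what `BatchNorm` does for a (width, channels, batch) input in
# ---- training mode; argument order follows the signature in this hunk.
using Lux
x = randn(Float32, 4, 3, 2)        # (width, channels, batch)
y, rmean, rvar = Lux.normalization(x, zeros(Float32, 3), ones(Float32, 3),
                                   ones(Float32, 3), zeros(Float32, 3),
                                   identity, [1, 3], Val(true), 0.1f0, 1.0f-5)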
return x_norm, safe_vec(running_mean_), safe_vec(running_var_) end @@ -126,18 +124,20 @@ end # Convolution @inline conv_wrapper(x, weight, cdims) = conv(x, weight, cdims) -@inline function conv_wrapper(x::SubArray{T,N,<:CuArray}, weight, cdims) where {T,N} +@inline function conv_wrapper(x::SubArray{T, N, <:CuArray}, weight, cdims) where {T, N} return conv(copy(x), weight, cdims) end # Dropout @inline _dropout_shape(s, ::Colon) = size(s) -@inline _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) in enumerate(size(s)))...) +@inline function _dropout_shape(s, dims) + tuple((i ∉ dims ? 1 : si for (i, si) in enumerate(size(s)))...) +end ## TODO: Cache `1 / q` since we never need `q` @inline _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0) -@inline function generate_dropout_mask(rng::AbstractRNG, x, p; dims=:) +@inline function generate_dropout_mask(rng::AbstractRNG, x, p; dims = :) realfptype = float(real(eltype(x))) y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims))) y .= _dropout_kernel.(y, p, 1 - p) @@ -188,19 +188,24 @@ end stride = insize .÷ outsize k = insize .- (outsize .- 1) .* stride pad = 0 - return PoolDims(x, k; padding=pad, stride=stride) + return PoolDims(x, k; padding = pad, stride = stride) end # CUDNN Constants const cudnnValidActivationTypes = Union{ - typeof(tanh),typeof(sigmoid),typeof(relu),typeof(elu),typeof(tanh_fast),typeof(sigmoid_fast) -} + typeof(tanh), typeof(sigmoid), typeof(relu), + typeof(elu), typeof(tanh_fast), typeof(sigmoid_fast) + } # Activation Functions ## I think this is handled by NNlibCUDA. But currently leaving here for ## benchmarking larger models -getCUDNNActivationMode(::Union{typeof(tanh),typeof(tanh_fast)}) = CUDNN.CUDNN_ACTIVATION_TANH -getCUDNNActivationMode(::Union{typeof(sigmoid),typeof(sigmoid_fast)}) = CUDNN.CUDNN_ACTIVATION_SIGMOID +function getCUDNNActivationMode(::Union{typeof(tanh), typeof(tanh_fast)}) + CUDNN.CUDNN_ACTIVATION_TANH +end +function getCUDNNActivationMode(::Union{typeof(sigmoid), typeof(sigmoid_fast)}) + CUDNN.CUDNN_ACTIVATION_SIGMOID +end getCUDNNActivationMode(::Union{typeof(relu)}) = CUDNN.CUDNN_ACTIVATION_RELU getCUDNNActivationMode(::Union{typeof(elu)}) = CUDNN.CUDNN_ACTIVATION_ELU @@ -211,7 +216,7 @@ Apply the function `f` on `x` elementwise, i.e. `f.(x)`. Dispatches to CUDNN if """ @inline applyactivation(f::Function, x::AbstractArray) = f.(x) @inline function applyactivation(f::cudnnValidActivationTypes, x::CuArray{<:CUDNNFloat}) - return CUDNN.cudnnActivationForward(x; mode=getCUDNNActivationMode(f)) + return CUDNN.cudnnActivationForward(x; mode = getCUDNNActivationMode(f)) end @inline applyactivation(::typeof(identity), x::AbstractArray) = x @@ -220,12 +225,14 @@ end sx = size(x) sΔ = size(Δ) sx == sΔ && return Δ - return sum(Δ; dims=findall(sx .!= sΔ)) + return sum(Δ; dims = findall(sx .!= sΔ)) end @inline isvalidtensorop(x1, x2) = false -@inline function isvalidtensorop(x1::CuArray{N,T}, x2::CuArray{N,T}) where {N,T<:CUDNNFloat} - return ndims(x1) <= 5 && (all(size(x2, i) == size(x1, i) || size(x2, i) == 1 for i in 1:ndims(x2))) +@inline function isvalidtensorop(x1::CuArray{N, T}, + x2::CuArray{N, T}) where {N, T <: CUDNNFloat} + return ndims(x1) <= 5 && + (all(size(x2, i) == size(x1, i) || size(x2, i) == 1 for i in 1:ndims(x2))) end """ @@ -236,10 +243,12 @@ Computes `x .+ y`. 
Dispatches to CUDNN if possible @inline elementwise_add(x, y) = x .+ y @inline function elementwise_add(x::CuArray, y::CuArray) !isvalidtensorop(x, y) && return x .+ y - return cudnnOpTensorWithDefaults(x, y; op=CUDNN.CUDNN_OP_TENSOR_ADD) + return cudnnOpTensorWithDefaults(x, y; op = CUDNN.CUDNN_OP_TENSOR_ADD) end -@inline elementwise_add_pullback(x, y, Δ) = broadcast_shape_pullback(x, Δ), broadcast_shape_pullback(y, Δ) +@inline function elementwise_add_pullback(x, y, Δ) + broadcast_shape_pullback(x, Δ), broadcast_shape_pullback(y, Δ) +end """ elementwise_mul(x, y) @@ -249,52 +258,52 @@ Computes `x .* y`. Dispatches to CUDNN if possible @inline elementwise_mul(x, y) = x .* y @inline function elementwise_mul(x::CuArray, y::CuArray) !isvalidtensorop(x, y) && return x .* y - return cudnnOpTensorWithDefaults(x, y; op=CUDNN.CUDNN_OP_TENSOR_MUL) + return cudnnOpTensorWithDefaults(x, y; op = CUDNN.CUDNN_OP_TENSOR_MUL) end @inline function elementwise_mul_pullback(x, y, Δ) - return broadcast_shape_pullback(x, elementwise_mul(Δ, y)), broadcast_shape_pullback(y, elementwise_mul(Δ, x)) + return broadcast_shape_pullback(x, elementwise_mul(Δ, y)), + broadcast_shape_pullback(y, elementwise_mul(Δ, x)) end # CUDNN Helpers -function cudnnOpTensorWithDefaults( - x1, - x2; - y=similar(x1), - op::CUDNN.cudnnOpTensorOp_t=CUDNN.CUDNN_OP_TENSOR_ADD, - compType::DataType=(eltype(x1) <: Float64 ? Float64 : Float32), - nanOpt::CUDNN.cudnnNanPropagation_t=CUDNN.CUDNN_NOT_PROPAGATE_NAN, - opTensorDesc::CUDNN.cudnnOpTensorDescriptor=CUDNN.cudnnOpTensorDescriptor( - op, CUDNN.cudnnDataType(compType), nanOpt - ), - alpha1::Real=1, - alpha2::Real=1, - beta::Real=0, - x1Desc::CUDNN.cudnnTensorDescriptor=CUDNN.cudnnTensorDescriptor(x1), - x2Desc::CUDNN.cudnnTensorDescriptor=CUDNN.cudnnTensorDescriptor(x2), - yDesc::CUDNN.cudnnTensorDescriptor=CUDNN.cudnnTensorDescriptor(y), -) +function cudnnOpTensorWithDefaults(x1, + x2; + y = similar(x1), + op::CUDNN.cudnnOpTensorOp_t = CUDNN.CUDNN_OP_TENSOR_ADD, + compType::DataType = (eltype(x1) <: Float64 ? 
Float64 : + Float32), + nanOpt::CUDNN.cudnnNanPropagation_t = CUDNN.CUDNN_NOT_PROPAGATE_NAN, + opTensorDesc::CUDNN.cudnnOpTensorDescriptor = CUDNN.cudnnOpTensorDescriptor(op, + CUDNN.cudnnDataType(compType), + nanOpt), + alpha1::Real = 1, + alpha2::Real = 1, + beta::Real = 0, + x1Desc::CUDNN.cudnnTensorDescriptor = CUDNN.cudnnTensorDescriptor(x1), + x2Desc::CUDNN.cudnnTensorDescriptor = CUDNN.cudnnTensorDescriptor(x2), + yDesc::CUDNN.cudnnTensorDescriptor = CUDNN.cudnnTensorDescriptor(y)) T = eltype(x1) alpha1, alpha2, beta = CUDNN.scalingParameter.((T,), (alpha1, alpha2, beta)) - return CUDNN.cudnnOpTensorAD(x1, x2; opTensorDesc, alpha1, x1Desc, alpha2, x2Desc, beta, yDesc, y) + return CUDNN.cudnnOpTensorAD(x1, x2; opTensorDesc, alpha1, x1Desc, alpha2, x2Desc, beta, + yDesc, y) end -function cudnnActivationBackward(y::CuArray{T}, Δ::CuArray{T}, x::CuArray{T}; mode) where {T} +function cudnnActivationBackward(y::CuArray{T}, Δ::CuArray{T}, x::CuArray{T}; + mode) where {T} Δx = similar(x) desc = CUDNN.cudnnActivationDescriptor(mode, CUDNN.CUDNN_NOT_PROPAGATE_NAN, Cdouble(1)) - CUDNN.cudnnActivationBackward( - CUDNN.handle(), - desc, - CUDNN.scalingParameter(T, 1), - CUDNN.cudnnTensorDescriptor(y), - y, - CUDNN.cudnnTensorDescriptor(Δ), - Δ, - CUDNN.cudnnTensorDescriptor(x), - x, - CUDNN.scalingParameter(T, 0), - CUDNN.cudnnTensorDescriptor(Δx), - Δx, - ) + CUDNN.cudnnActivationBackward(CUDNN.handle(), + desc, + CUDNN.scalingParameter(T, 1), + CUDNN.cudnnTensorDescriptor(y), + y, + CUDNN.cudnnTensorDescriptor(Δ), + Δ, + CUDNN.cudnnTensorDescriptor(x), + x, + CUDNN.scalingParameter(T, 0), + CUDNN.cudnnTensorDescriptor(Δx), + Δx) return Δx -end \ No newline at end of file +end diff --git a/src/transform.jl b/src/transform.jl index 842382e63..04bd8e34e 100644 --- a/src/transform.jl +++ b/src/transform.jl @@ -24,22 +24,21 @@ transform(::T) where {T} = error("Transformation for type $T not implemented") transform(model::Flux.Chain) = Chain(transform.(model.layers)...) function transform(model::Flux.BatchNorm) - return BatchNorm( - model.chs, model.λ; affine=model.affine, track_stats=model.track_stats, epsilon=model.ϵ, momentum=model.momentum - ) + return BatchNorm(model.chs, model.λ; affine = model.affine, + track_stats = model.track_stats, epsilon = model.ϵ, + momentum = model.momentum) end function transform(model::Flux.Conv) - return Conv( - size(model.weight)[1:(end - 2)], - size(model.weight, ndims(model.weight) - 1) * model.groups => size(model.weight, ndims(model.weight)), - model.σ; - stride=model.stride, - pad=model.pad, - bias=model.bias isa Bool ? model.bias : !(model.bias isa Flux.Zeros), - dilation=model.dilation, - groups=model.groups, - ) + return Conv(size(model.weight)[1:(end - 2)], + size(model.weight, ndims(model.weight) - 1) * model.groups => size(model.weight, + ndims(model.weight)), + model.σ; + stride = model.stride, + pad = model.pad, + bias = model.bias isa Bool ? 
model.bias : !(model.bias isa Flux.Zeros), + dilation = model.dilation, + groups = model.groups) end function transform(model::Flux.SkipConnection) @@ -79,7 +78,7 @@ function transform(model::Flux.Parallel) end function transform(d::Flux.Dropout) - return Dropout(Float32(d.p); dims=d.dims) + return Dropout(Float32(d.p); dims = d.dims) end transform(::typeof(identity)) = NoOpLayer() diff --git a/src/utils.jl b/src/utils.jl index c275c49b2..2b138b3de 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,7 +5,7 @@ nfan() = 1, 1 # fan_in, fan_out nfan(n) = 1, n # A vector is treated as a n×1 matrix nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices nfan(dims::Tuple) = nfan(dims...) -nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of convolution kernels +nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels # Neural Network Initialization ## NOTE: Would be great if these could be moved into its own package and NN frameworks @@ -36,7 +36,7 @@ Return an `Array{Float32}` of the given `size` containing random numbers drawn f [1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010. """ -function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1) +function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real = 1) scale = Float32(gain) * sqrt(24.0f0 / sum(nfan(dims...))) return (rand(rng, Float32, dims...) .- 0.5f0) .* scale end @@ -50,7 +50,7 @@ Return an `Array{Float32}` of the given `size` containing random numbers drawn f [1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010. """ -function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1) +function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real = 1) std = Float32(gain) * sqrt(2.0f0 / sum(nfan(dims...))) return randn(rng, Float32, dims...) .* std end @@ -65,53 +65,64 @@ replicate(rng::CUDA.RNG) = deepcopy(rng) @inline istraining(st::NamedTuple) = istraining(st.training) # Linear Algebra -@inline _norm(x; dims=Colon()) = sqrt.(sum(abs2, x; dims=dims)) -@inline _norm_except(x::AbstractArray{T,N}, except_dim=N) where {T,N} = _norm(x; dims=filter(i -> i != except_dim, 1:N)) +@inline _norm(x; dims = Colon()) = sqrt.(sum(abs2, x; dims = dims)) +@inline function _norm_except(x::AbstractArray{T, N}, except_dim = N) where {T, N} + _norm(x; dims = filter(i -> i != except_dim, 1:N)) +end # Convolution -function convfilter(rng::AbstractRNG, filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; - init = glorot_uniform, groups = 1) where N +function convfilter(rng::AbstractRNG, filter::NTuple{N, Integer}, + ch::Pair{<:Integer, <:Integer}; + init = glorot_uniform, groups = 1) where {N} cin, cout = ch - @assert cin % groups == 0 "Input channel dimension must be divisible by groups." - @assert cout % groups == 0 "Output channel dimension must be divisible by groups." - return init(rng, filter..., cin÷groups, cout) + @assert cin % groups==0 "Input channel dimension must be divisible by groups." + @assert cout % groups==0 "Output channel dimension must be divisible by groups." 
+ return init(rng, filter..., cin ÷ groups, cout) end expand(N, i::Tuple) = i expand(N, i::Integer) = ntuple(_ -> i, N) _maybetuple_string(pad) = string(pad) -_maybetuple_string(pad::Tuple) = all(==(pad[1]), pad) ? string(pad[1]) : string(pad) +_maybetuple_string(pad::Tuple) = all(==(pad[1]), pad) ? string(pad[1]) : string(pad) # Padding struct SamePad end -calc_padding(lt, pad, k::NTuple{N,T}, dilation, stride) where {T,N}= expand(Val(2*N), pad) +function calc_padding(lt, pad, k::NTuple{N, T}, dilation, stride) where {T, N} + expand(Val(2 * N), pad) +end -function calc_padding(lt, ::SamePad, k::NTuple{N,T}, dilation, stride) where {N,T} +function calc_padding(lt, ::SamePad, k::NTuple{N, T}, dilation, stride) where {N, T} # Ref: "A guide to convolution arithmetic for deep learning" https://arxiv.org/abs/1603.07285 # Effective kernel size, including dilation k_eff = @. k + (k - 1) * (dilation - 1) # How much total padding needs to be applied? pad_amt = @. k_eff - 1 # In case amount of padding is odd we need to apply different amounts to each side. - return Tuple(mapfoldl(i -> [cld(i, 2), fld(i,2)], vcat, pad_amt)) + return Tuple(mapfoldl(i -> [cld(i, 2), fld(i, 2)], vcat, pad_amt)) end # Handling ComponentArrays ## NOTE: We should probably upsteam some of these -Base.zero(c::ComponentArray{T,N,<:CuArray{T}}) where {T,N} = ComponentArray(zero(getdata(c)), getaxes(c)) +function Base.zero(c::ComponentArray{T, N, <:CuArray{T}}) where {T, N} + ComponentArray(zero(getdata(c)), getaxes(c)) +end -Base.vec(c::ComponentArray{T,N,<:CuArray{T}}) where {T,N} = getdata(c) +Base.vec(c::ComponentArray{T, N, <:CuArray{T}}) where {T, N} = getdata(c) -Base.:-(x::ComponentArray{T,N,<:CuArray{T}}) where {T,N} = ComponentArray(-getdata(x), getaxes(x)) +function Base.:-(x::ComponentArray{T, N, <:CuArray{T}}) where {T, N} + ComponentArray(-getdata(x), getaxes(x)) +end -function Base.similar(c::ComponentArray{T,N,<:CuArray{T}}, l::Vararg{Union{Integer,AbstractUnitRange}}) where {T,N} +function Base.similar(c::ComponentArray{T, N, <:CuArray{T}}, + l::Vararg{Union{Integer, AbstractUnitRange}}) where {T, N} return similar(getdata(c), l) end function Functors.functor(::Type{<:ComponentArray}, c) - return NamedTuple{propertynames(c)}(getproperty.((c,), propertynames(c))), ComponentArray + return NamedTuple{propertynames(c)}(getproperty.((c,), propertynames(c))), + ComponentArray end # Updating a monolithic vector is way faster than chunks of smaller ones @@ -130,7 +141,8 @@ end function ComponentArrays.make_carray_args(nt::NamedTuple) data, ax = ComponentArrays.make_carray_args(Vector, nt) - data = length(data) == 0 ? Float32[] : (length(data)==1 ? [data[1]] : reduce(vcat, data)) + data = length(data) == 0 ? Float32[] : + (length(data) == 1 ? [data[1]] : reduce(vcat, data)) return (data, ax) end @@ -144,7 +156,7 @@ end ComponentArrays.recursive_length(nt::NamedTuple{(), Tuple{}}) = 0 # Return Nothing if field not present -function safe_getproperty(x::Union{ComponentArray,NamedTuple}, k::Symbol) +function safe_getproperty(x::Union{ComponentArray, NamedTuple}, k::Symbol) k ∈ propertynames(x) && return getproperty(x, k) return nothing end @@ -157,7 +169,8 @@ get_typename(::T) where {T} = Base.typename(T).wrapper @inline @generated safe_vec(x::T) where {T} = hasmethod(vec, (T,)) ? 
:(vec(x)) : :x -@inline @inbounds function get_reshape_dims(sx::NTuple{N,<:Int}, ly::Int)::typeof(sx) where {N} +@inline @inbounds function get_reshape_dims(sx::NTuple{N, <:Int}, + ly::Int)::typeof(sx) where {N} if ly == sx[N - 1] return ntuple(i -> i == N - 1 ? ly : 1, N) elseif N > 2 && ly == sx[N - 1] * sx[N - 2] @@ -168,7 +181,9 @@ get_typename(::T) where {T} = Base.typename(T).wrapper end @inline reshape_into_proper_shape(x::Nothing, y)::Nothing = x -@inline reshape_into_proper_shape(x, y)::typeof(y) = reshape(x, get_reshape_dims(size(y), length(x))) +@inline reshape_into_proper_shape(x, y)::typeof(y) = reshape(x, + get_reshape_dims(size(y), + length(x))) # RNN Utilities @inline gate(h::Int, n::Int) = (1:h) .+ h * (n - 1) @@ -180,4 +195,4 @@ end Split up `x` into `N` equally sized chunks (along dimension `1`). """ -@inline multigate(x::AbstractArray, ::Val{N}) where N = gate.((x,), size(x, 1) ÷ N, 1:N) +@inline multigate(x::AbstractArray, ::Val{N}) where {N} = gate.((x,), size(x, 1) ÷ N, 1:N) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index c20e6f618..17e0e5e53 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -14,8 +14,9 @@ Random.seed!(rng, 0) @test size(layer(x, ps, st)[1]) == (2, 3, 3) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1f-3, rtol=1f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, + rtol = 1.0f-3) end @testset "Flatten Layer" begin @@ -26,8 +27,9 @@ Random.seed!(rng, 0) @test size(layer(x, ps, st)[1]) == (18, 2) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1f-3, rtol=1f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, + rtol = 1.0f-3) end @testset "NoOpLayer" begin @@ -38,10 +40,11 @@ Random.seed!(rng, 0) @test layer(x, ps, st)[1] == x @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) + @test_opt target_modules=(Lux,) layer(x, ps, st) x = randn(rng, 6, 3) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1f-3, rtol=1f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, + rtol = 1.0f-3) end @testset "SelectDim Layer" begin @@ -52,8 +55,9 @@ Random.seed!(rng, 0) @test size(layer(x, ps, st)[1]) == (6, 4, 2) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) broken=true - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1f-3, rtol=1f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) broken=true + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, + rtol = 1.0f-3) end @testset "WrappedFunction" begin @@ -64,8 +68,9 @@ Random.seed!(rng, 0) @test layer(x, ps, st)[1] == x .* x @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1f-3, rtol=1f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, + rtol = 1.0f-3) end @testset "ActivationFunction" begin @@ -76,8 +81,9 @@ Random.seed!(rng, 0) @test layer(x, ps, st)[1] == tanh.(x) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - 
test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1f-3, rtol=1f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, + rtol = 1.0f-3) end end @@ -91,8 +97,9 @@ end @test layer(x, ps, st)[1] == x @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1f-3, rtol=1f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, + rtol = 1.0f-3) end @testset "concat size" begin @@ -103,8 +110,9 @@ end @test size(layer(x, ps, st)[1]) == (10, 4) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1f-3, rtol=1f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; + atol = 1.0f-3, rtol = 1.0f-3) end end @@ -117,20 +125,22 @@ end @test layer(x, ps, st)[1] == x @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1f-3, rtol=1f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, + rtol = 1.0f-3) end @testset "concat size" begin - layer = Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), NoOpLayer()) + layer = Parallel((a, b) -> cat(a, b; dims = 2), Dense(10, 10), NoOpLayer()) println(layer) ps, st = Lux.setup(rng, layer) x = randn(rng, 10, 2) @test size(layer(x, ps, st)[1]) == (10, 4) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1f-3, rtol=1f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, + rtol = 1.0f-3) layer = Parallel(hcat, Dense(10, 10), NoOpLayer()) println(layer) @@ -138,8 +148,9 @@ end @test size(layer(x, ps, st)[1]) == (10, 4) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1f-3, rtol=1f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, + rtol = 1.0f-3) end @testset "vararg input" begin @@ -150,14 +161,16 @@ end @test size(layer(x, ps, st)[1]) == (2, 1) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1f-3, rtol=1f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; + atol = 1.0f-3, rtol = 1.0f-3) end @testset "connection is called once" begin CNT = Ref(0) f_cnt = (x...) 
-> (CNT[] += 1; +(x...)) - layer = Parallel(f_cnt, WrappedFunction(sin), WrappedFunction(cos), WrappedFunction(tan)) + layer = Parallel(f_cnt, WrappedFunction(sin), WrappedFunction(cos), + WrappedFunction(tan)) ps, st = Lux.setup(rng, layer) Lux.apply(layer, 1, ps, st) @test CNT[] == 1 @@ -171,12 +184,12 @@ end # Ref https://github.com/FluxML/Flux.jl/issues/1673 @testset "Input domain" begin struct Input - x + x::Any end struct L1 <: Lux.AbstractExplicitLayer end (::L1)(x, ps, st) = (ps.x * x, st) - Lux.initialparameters(rng::AbstractRNG, ::L1) = (x=randn(rng, Float32, 3, 3),) + Lux.initialparameters(rng::AbstractRNG, ::L1) = (x = randn(rng, Float32, 3, 3),) Base.:*(a::AbstractArray, b::Input) = a * b.x par = Parallel(+, L1(), L1()) @@ -186,12 +199,15 @@ end ip2 = Input(rand(Float32, 3, 3)) @test par(ip, ps, st)[1] ≈ - par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + par.layers[2](ip.x, ps.layer_2, st.layer_2)[1] + par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + + par.layers[2](ip.x, ps.layer_2, st.layer_2)[1] @test par((ip, ip2), ps, st)[1] ≈ - par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + par.layers[2](ip2.x, ps.layer_2, st.layer_2)[1] + par.layers[1](ip.x, ps.layer_1, st.layer_1)[1] + + par.layers[2](ip2.x, ps.layer_2, st.layer_2)[1] gs = gradient((p, x...) -> sum(par(x, p, st)[1]), ps, ip, ip2) gs_reg = gradient(ps, ip, ip2) do p, x, y - sum(par.layers[1](x.x, p.layer_1, st.layer_1)[1] + par.layers[2](y.x, p.layer_2, st.layer_2)[1]) + sum(par.layers[1](x.x, p.layer_1, st.layer_1)[1] + + par.layers[2](y.x, p.layer_2, st.layer_2)[1]) end @test gs[1] ≈ gs_reg[1] @@ -210,7 +226,7 @@ end @test size(ps.bias) == (100, 1) @test layer.activation == identity - layer = Dense(10, 100, relu; bias=false) + layer = Dense(10, 100, relu; bias = false) ps, st = Lux.setup(rng, layer) @test !haskey(ps, :bias) @@ -227,27 +243,27 @@ end @testset "zeros" begin @test begin - layer = Dense(10, 1, identity; init_weight=ones) + layer = Dense(10, 1, identity; init_weight = ones) first(Lux.apply(layer, ones(10, 1), Lux.setup(rng, layer)...)) end == 10 * ones(1, 1) @test begin - layer = Dense(10, 1, identity; init_weight=ones) + layer = Dense(10, 1, identity; init_weight = ones) first(Lux.apply(layer, ones(10, 2), Lux.setup(rng, layer)...)) end == 10 * ones(1, 2) @test begin - layer = Dense(10, 2, identity; init_weight=ones) + layer = Dense(10, 2, identity; init_weight = ones) first(Lux.apply(layer, ones(10, 1), Lux.setup(rng, layer)...)) end == 10 * ones(2, 1) @test begin - layer = Dense(10, 2, identity; init_weight=ones) + layer = Dense(10, 2, identity; init_weight = ones) first(Lux.apply(layer, [ones(10, 1) 2 * ones(10, 1)], Lux.setup(rng, layer)...)) end == [10 20; 10 20] @test begin - layer = Dense(10, 2, identity; init_weight=ones, bias=false) + layer = Dense(10, 2, identity; init_weight = ones, bias = false) first(Lux.apply(layer, [ones(10, 1) 2 * ones(10, 1)], Lux.setup(rng, layer)...)) end == [10 20; 10 20] end diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 8be3f4e72..0fe3378fe 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -14,110 +14,115 @@ include("../utils.jl") @test layer(x, ps, st)[1] == maxpool(x, PoolDims(x, 2)) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) + @test_opt target_modules=(Lux,) layer(x, ps, st) layer = AdaptiveMeanPool((5, 5)) ps, st = Lux.setup(rng, layer) @test layer(x, ps, st)[1] == meanpool(x, PoolDims(x, 2)) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) + @test_opt 
target_modules=(Lux,) layer(x, ps, st) layer = AdaptiveMaxPool((10, 5)) ps, st = Lux.setup(rng, layer) @test layer(y, ps, st)[1] == maxpool(y, PoolDims(y, (2, 4))) @test_call layer(y, ps, st) - @test_opt target_modules = (Lux,) layer(y, ps, st) + @test_opt target_modules=(Lux,) layer(y, ps, st) layer = AdaptiveMeanPool((10, 5)) ps, st = Lux.setup(rng, layer) @test layer(y, ps, st)[1] == meanpool(y, PoolDims(y, (2, 4))) @test_call layer(y, ps, st) - @test_opt target_modules = (Lux,) layer(y, ps, st) + @test_opt target_modules=(Lux,) layer(y, ps, st) layer = GlobalMaxPool() ps, st = Lux.setup(rng, layer) @test size(layer(x, ps, st)[1]) == (1, 1, 3, 2) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) + @test_opt target_modules=(Lux,) layer(x, ps, st) layer = GlobalMeanPool() ps, st = Lux.setup(rng, layer) @test size(layer(x, ps, st)[1]) == (1, 1, 3, 2) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) + @test_opt target_modules=(Lux,) layer(x, ps, st) layer = MaxPool((2, 2)) ps, st = Lux.setup(rng, layer) @test layer(x, ps, st)[1] == maxpool(x, PoolDims(x, 2)) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) + @test_opt target_modules=(Lux,) layer(x, ps, st) layer = MeanPool((2, 2)) ps, st = Lux.setup(rng, layer) @test layer(x, ps, st)[1] == meanpool(x, PoolDims(x, 2)) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) + @test_opt target_modules=(Lux,) layer(x, ps, st) + + @testset "$ltype SamePad windowsize $k" for ltype in (MeanPool, MaxPool), + k in ((1,), (2,), (3,), (4, 5), (6, 7, 8)) - @testset "$ltype SamePad windowsize $k" for ltype in (MeanPool, MaxPool), k in ((1,), (2,), (3,), (4, 5), (6, 7, 8)) x = ones(Float32, (k .+ 3)..., 1, 1) - layer = ltype(k; pad=Lux.SamePad()) + layer = ltype(k; pad = Lux.SamePad()) ps, st = Lux.setup(rng, layer) @test size(layer(x, ps, st)[1])[1:(end - 2)] == cld.(size(x)[1:(end - 2)], k) - @test_call layer(x, ps, st) broken=length(k)==1 - @test_opt target_modules = (Lux,) layer(x, ps, st) + @test_call layer(x, ps, st) broken=length(k) == 1 + @test_opt target_modules=(Lux,) layer(x, ps, st) end end @testset "CNN" begin @testset "Grouped Conv" begin x = rand(rng, Float32, 4, 6, 1) - layer = Conv((3,), 6 => 2; groups=2) + layer = Conv((3,), 6 => 2; groups = 2) ps, st = Lux.setup(rng, layer) @test size(ps.weight) == (3, 3, 2) @test size(layer(x, ps, st)[1]) == (2, 2, 1) @test_call layer(x, ps, st) broken=true - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, rtol=1.0f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; + atol = 1.0f-3, rtol = 1.0f-3) x = rand(rng, Float32, 4, 4, 6, 1) - layer = Conv((3, 3), 6 => 2; groups=2) + layer = Conv((3, 3), 6 => 2; groups = 2) ps, st = Lux.setup(rng, layer) @test size(ps.weight) == (3, 3, 3, 2) @test size(layer(x, ps, st)[1]) == (2, 2, 2, 1) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, rtol=1.0f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; + atol = 1.0f-3, rtol = 1.0f-3) x = rand(rng, Float32, 4, 4, 4, 6, 1) - layer = Conv((3, 3, 3), 6 => 2; groups=2) + layer = Conv((3, 3, 3), 6 => 2; groups = 2) ps, st = Lux.setup(rng, 
layer) @test size(ps.weight) == (3, 3, 3, 3, 2) @test size(layer(x, ps, st)[1]) == (2, 2, 2, 2, 1) @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, rtol=1.0f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; + atol = 1.0f-3, rtol = 1.0f-3) # Test that we cannot ask for non-integer multiplication factors - layer = Conv((2, 2), 3 => 10, groups=2) + layer = Conv((2, 2), 3 => 10, groups = 2) @test_throws AssertionError Lux.setup(rng, layer) - layer = Conv((2, 2), 2 => 9, groups=2) + layer = Conv((2, 2), 2 => 9, groups = 2) @test_throws AssertionError Lux.setup(rng, layer) end @testset "Asymmetric Padding" begin - layer = Conv((3, 3), 1 => 1, relu; pad=(0, 1, 1, 2)) + layer = Conv((3, 3), 1 => 1, relu; pad = (0, 1, 1, 2)) x = ones(Float32, 28, 28, 1, 1) ps, st = Lux.setup(rng, layer) @@ -134,106 +139,112 @@ end @test y_hat[end, end] ≈ 2.0 @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) + @test_opt target_modules=(Lux,) layer(x, ps, st) end @testset "Variable BitWidth Parameters" begin # https://github.com/FluxML/Flux.jl/issues/1421 - layer = Conv((5, 5), 10 => 20, identity; init_weight=Base.randn) + layer = Conv((5, 5), 10 => 20, identity; init_weight = Base.randn) ps, st = Lux.setup(rng, layer) - @test ps.bias isa Array{Float64,4} + @test ps.bias isa Array{Float64, 4} end @testset "Depthwise Conv" begin x = randn(rng, Float32, 4, 4, 3, 2) - layer = Conv((2, 2), 3 => 15; groups=3) + layer = Conv((2, 2), 3 => 15; groups = 3) ps, st = Lux.setup(rng, layer) @test size(layer(x, ps, st)[1], 3) == 15 @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, rtol=1.0f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; + atol = 1.0f-3, rtol = 1.0f-3) - layer = Conv((2, 2), 3 => 9; groups=3) + layer = Conv((2, 2), 3 => 9; groups = 3) ps, st = Lux.setup(rng, layer) @test size(layer(x, ps, st)[1], 3) == 9 @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, rtol=1.0f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; + atol = 1.0f-3, rtol = 1.0f-3) - layer = Conv((2, 2), 3 => 9; groups=3, bias=false) + layer = Conv((2, 2), 3 => 9; groups = 3, bias = false) ps, st = Lux.setup(rng, layer) @test size(layer(x, ps, st)[1], 3) == 9 @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, rtol=1.0f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; + atol = 1.0f-3, rtol = 1.0f-3) # Test that we cannot ask for non-integer multiplication factors - layer = Conv((2, 2), 3 => 10; groups=3) + layer = Conv((2, 2), 3 => 10; groups = 3) @test_throws AssertionError Lux.setup(rng, layer) end @testset "Conv SamePad kernelsize $k" for k in ((1,), (2,), (3,), (2, 3), (1, 2, 3)) x = ones(Float32, (k .+ 3)..., 1, 1) - layer = Conv(k, 1 => 1; pad=Lux.SamePad()) + layer = Conv(k, 1 => 1; pad = Lux.SamePad()) ps, st = 
Lux.setup(rng, layer) @test size(layer(x, ps, st)[1]) == size(x) - @test_call layer(x, ps, st) broken=length(k)==1 - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, rtol=1.0f-3) + @test_call layer(x, ps, st) broken=length(k) == 1 + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; + atol = 1.0f-3, rtol = 1.0f-3) - layer = Conv(k, 1 => 1; pad=Lux.SamePad(), dilation=k .÷ 2) + layer = Conv(k, 1 => 1; pad = Lux.SamePad(), dilation = k .÷ 2) ps, st = Lux.setup(rng, layer) @test size(layer(x, ps, st)[1]) == size(x) - @test_call layer(x, ps, st) broken=length(k)==1 - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, rtol=1.0f-3) + @test_call layer(x, ps, st) broken=length(k) == 1 + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; + atol = 1.0f-3, rtol = 1.0f-3) stride = 3 - layer = Conv(k, 1 => 1; pad=Lux.SamePad(), stride=stride) + layer = Conv(k, 1 => 1; pad = Lux.SamePad(), stride = stride) ps, st = Lux.setup(rng, layer) @test size(layer(x, ps, st)[1])[1:(end - 2)] == cld.(size(x)[1:(end - 2)], stride) - @test_call layer(x, ps, st) broken=length(k)==1 - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; atol=1.0f-3, rtol=1.0f-3) + @test_call layer(x, ps, st) broken=length(k) == 1 + @test_opt target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; + atol = 1.0f-3, rtol = 1.0f-3) end @testset "Conv with non quadratic window #700" begin - x = zeros(Float32, 7,7,1,1) - x[4,4,1,1] = 1 + x = zeros(Float32, 7, 7, 1, 1) + x[4, 4, 1, 1] = 1 - layer = Conv((3,3), 1=>1) + layer = Conv((3, 3), 1 => 1) ps, st = Lux.setup(rng, layer) - y = zeros(eltype(ps.weight),5,5,1,1) - y[2:end-1,2:end-1,1,1] = ps.weight + y = zeros(eltype(ps.weight), 5, 5, 1, 1) + y[2:(end - 1), 2:(end - 1), 1, 1] = ps.weight @test y ≈ layer(x, ps, st)[1] @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) + @test_opt target_modules=(Lux,) layer(x, ps, st) - layer = Conv((3,1), 1=>1) + layer = Conv((3, 1), 1 => 1) ps, st = Lux.setup(rng, layer) - y = zeros(eltype(ps.weight),5,7,1,1) - y[2:end-1,4,1,1] = ps.weight + y = zeros(eltype(ps.weight), 5, 7, 1, 1) + y[2:(end - 1), 4, 1, 1] = ps.weight @test y ≈ layer(x, ps, st)[1] @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) + @test_opt target_modules=(Lux,) layer(x, ps, st) - layer = Conv((1,3), 1=>1) + layer = Conv((1, 3), 1 => 1) ps, st = Lux.setup(rng, layer) - y = zeros(eltype(ps.weight),7,5,1,1) - y[4,2:end-1,1,1] = ps.weight + y = zeros(eltype(ps.weight), 7, 5, 1, 1) + y[4, 2:(end - 1), 1, 1] = ps.weight @test y ≈ layer(x, ps, st)[1] @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) + @test_opt target_modules=(Lux,) layer(x, ps, st) end end diff --git a/test/layers/dropout.jl b/test/layers/dropout.jl index 1a39661bf..e030fd85d 100644 --- a/test/layers/dropout.jl +++ b/test/layers/dropout.jl @@ -21,8 +21,9 @@ Random.seed!(rng, 0) @test x_ != x___ @test_call layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1f-3, rtol=1f-3) + @test_opt 
target_modules=(Lux,) layer(x, ps, st) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, + rtol = 1.0f-3) st = Lux.testmode(st) @@ -47,10 +48,12 @@ end @test_call layer(x, ps, st) @test_call layer(x, ps, st_) - @test_opt target_modules = (Lux,) layer(x, ps, st) - @test_opt target_modules = (Lux,) layer(x, ps, st_) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1f-3, rtol=1f-3) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st_)[1]), x; atol=1f-3, rtol=1f-3) + @test_opt target_modules=(Lux,) layer(x, ps, st) + @test_opt target_modules=(Lux,) layer(x, ps, st_) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, + rtol = 1.0f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st_)[1]), x; atol = 1.0f-3, + rtol = 1.0f-3) st__ = Lux.update_state(st_, :update_mask, Val(true)) x___, st___ = layer(x, ps, st__) @@ -59,6 +62,7 @@ end @test x___ != x_ @test_call layer(x, ps, st__) - @test_opt target_modules = (Lux,) layer(x, ps, st__) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st__)[1]), x; atol=1f-3, rtol=1f-3) -end \ No newline at end of file + @test_opt target_modules=(Lux,) layer(x, ps, st__) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st__)[1]), x; atol = 1.0f-3, + rtol = 1.0f-3) +end diff --git a/test/layers/normalize.jl b/test/layers/normalize.jl index 62c7aeb42..9d1eb4b5e 100644 --- a/test/layers/normalize.jl +++ b/test/layers/normalize.jl @@ -6,10 +6,8 @@ rng = Random.default_rng() Random.seed!(rng, 0) @testset "BatchNorm" begin - let m = BatchNorm(2), x = [ - 1.0f0 3.0f0 5.0f0 - 2.0f0 4.0f0 6.0f0 - ] + let m = BatchNorm(2), x = [1.0f0 3.0f0 5.0f0 + 2.0f0 4.0f0 6.0f0] ps, st = Lux.setup(rng, m) @test Lux.parameterlength(m) == 2 * 2 @@ -18,7 +16,7 @@ Random.seed!(rng, 0) @test ps.scale == [1, 1] # init_scale(2) y, st_ = pullback(m, x, ps, st)[1] - @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol=1.0e-5) + @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5) # julia> x # 2×3 Array{Float64,2}: # 1.0 3.0 5.0 @@ -37,43 +35,47 @@ Random.seed!(rng, 0) # 2×1 Array{Float64,2}: # 1.3 # 1.3 - @test st_.running_var ≈ 0.1 .* var(x; dims=2, corrected=false) .* (3 / 2) .+ 0.9 .* [1.0, 1.0] + @test st_.running_var ≈ + 0.1 .* var(x; dims = 2, corrected = false) .* (3 / 2) .+ 0.9 .* [1.0, 1.0] st_ = Lux.testmode(st_) x′ = m(x, ps, st_)[1] - @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol=1.0e-5) + @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5) @inferred m(x, ps, st) @test_call m(x, ps, st) - @test_opt target_modules = (Lux,) m(x, ps, st) + @test_opt target_modules=(Lux,) m(x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps; atol=1.0f-3, rtol=1.0f-3) + test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps; + atol = 1.0f-3, rtol = 1.0f-3) end - let m = BatchNorm(2; track_stats=false), x = [1.0f0 3.0f0 5.0f0; 2.0f0 4.0f0 6.0f0] + let m = BatchNorm(2; track_stats = false), x = [1.0f0 3.0f0 5.0f0; 2.0f0 4.0f0 6.0f0] ps, st = Lux.setup(rng, m) @inferred m(x, ps, st) @test_call m(x, ps, st) - @test_opt target_modules = (Lux,) m(x, ps, st) + @test_opt target_modules=(Lux,) m(x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps; atol=1.0f-3, rtol=1.0f-3) + test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps; + atol = 1.0f-3, rtol = 1.0f-3) end # with activation function - let m = BatchNorm(2, sigmoid), x = [ - 1.0f0 3.0f0 5.0f0 - 2.0f0 4.0f0 
6.0f0 - ] + let m = BatchNorm(2, sigmoid), x = [1.0f0 3.0f0 5.0f0 + 2.0f0 4.0f0 6.0f0] ps, st = Lux.setup(rng, m) st = Lux.testmode(st) y, st_ = m(x, ps, st) - @test isapprox(y, sigmoid.((x .- st_.running_mean) ./ sqrt.(st_.running_var .+ m.epsilon)), atol=1.0e-7) + @test isapprox(y, + sigmoid.((x .- st_.running_mean) ./ + sqrt.(st_.running_var .+ m.epsilon)), atol = 1.0e-7) @inferred m(x, ps, st) @test_call m(x, ps, st) - @test_opt target_modules = (Lux,) m(x, ps, st) + @test_opt target_modules=(Lux,) m(x, ps, st) - test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps; atol=1.0f-3, rtol=1.0f-3) + test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps; + atol = 1.0f-3, rtol = 1.0f-3) end let m = BatchNorm(2), x = reshape(Float32.(1:6), 3, 2, 1) @@ -89,15 +91,17 @@ Random.seed!(rng, 0) @test (@allocated m(x, ps, st)) < 100_000_000 @inferred m(x, ps, st) @test_call m(x, ps, st) - @test_opt target_modules = (Lux,) m(x, ps, st) + @test_opt target_modules=(Lux,) m(x, ps, st) end end @testset "GroupNorm" begin # begin tests - squeeze(x) = dropdims(x; dims=tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions + squeeze(x) = dropdims(x; dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions + + let m = GroupNorm(4, 2; track_stats = true), sizes = (3, 4, 2), + x = reshape(collect(1:prod(sizes)), sizes) - let m = GroupNorm(4, 2; track_stats=true), sizes = (3, 4, 2), x = reshape(collect(1:prod(sizes)), sizes) @test Lux.parameterlength(m) == 2 * 4 x = Float32.(x) ps, st = Lux.setup(rng, m) @@ -135,17 +139,19 @@ end n = prod(size(x)) ÷ m.groups ÷ size(x)[end] corr = n / (n - 1) z = reshape(x, 3, 2, 2, 2) - variance = var(z; dims=(1, 2), corrected=false) - @test st_.running_var ≈ 0.1 * corr * vec(mean(variance; dims=4)) .+ 0.9 * 1 + variance = var(z; dims = (1, 2), corrected = false) + @test st_.running_var ≈ 0.1 * corr * vec(mean(variance; dims = 4)) .+ 0.9 * 1 st__ = Lux.testmode(st_) y, st__ = m(x, ps, st__) - out = (z .- reshape(st_.running_mean, 1, 1, 2, 1)) ./ sqrt.(reshape(st_.running_var, 1, 1, 2, 1) .+ 1.0f-5) - @test y ≈ reshape(out, size(x)) atol = 1.0e-5 + out = (z .- reshape(st_.running_mean, 1, 1, 2, 1)) ./ + sqrt.(reshape(st_.running_var, 1, 1, 2, 1) .+ 1.0f-5) + @test y≈reshape(out, size(x)) atol=1.0e-5 @inferred m(x, ps, st) @test_call m(x, ps, st) - @test_opt target_modules = (Lux,) m(x, ps, st) - test_gradient_correctness_fdm(ps -> sum(first(m(x, ps, st))), ps; atol=1.0f-3, rtol=1.0f-3) + @test_opt target_modules=(Lux,) m(x, ps, st) + test_gradient_correctness_fdm(ps -> sum(first(m(x, ps, st))), ps; atol = 1.0f-3, + rtol = 1.0f-3) end end diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index b660f664c..5265497a3 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -5,34 +5,30 @@ include("../utils.jl") rng = Random.default_rng() Random.seed!(rng, 0) -@testset "RNNCell" begin - for rnncell in ( - RNNCell(3 => 5, identity), - RNNCell(3 => 5, tanh), - RNNCell(3 => 5, tanh; bias=false), - RNNCell(3 => 5, identity; bias=false), - ) - println(rnncell) - ps, st = Lux.setup(rng, rnncell) - x = randn(rng, Float32, 3, 2) - h, st_ = Lux.apply(rnncell, x, ps, st) +@testset "RNNCell" begin for rnncell in (RNNCell(3 => 5, identity), + RNNCell(3 => 5, tanh), + RNNCell(3 => 5, tanh; bias = false), + RNNCell(3 => 5, identity; bias = false)) + println(rnncell) + ps, st = Lux.setup(rng, rnncell) + x = randn(rng, Float32, 3, 2) + h, st_ = Lux.apply(rnncell, x, ps, st) - @test_call rnncell(x, ps, 
st) - @test_opt target_modules = (Lux,) rnncell(x, ps, st) - @test_call rnncell((x, h), ps, st_) - @test_opt target_modules = (Lux,) rnncell((x, h), ps, st_) + @test_call rnncell(x, ps, st) + @test_opt target_modules=(Lux,) rnncell(x, ps, st) + @test_call rnncell((x, h), ps, st_) + @test_opt target_modules=(Lux,) rnncell((x, h), ps, st_) - function loss_loop_rnncell(p) - h, st_ = rnncell(x, p, st) - for i in 1:10 - h, st_ = rnncell((x, h), p, st_) - end - return sum(abs2, h) + function loss_loop_rnncell(p) + h, st_ = rnncell(x, p, st) + for i in 1:10 + h, st_ = rnncell((x, h), p, st_) end - - test_gradient_correctness_fdm(loss_loop_rnncell, ps, atol=1e-3, rtol=1e-3) + return sum(abs2, h) end -end + + test_gradient_correctness_fdm(loss_loop_rnncell, ps, atol = 1e-3, rtol = 1e-3) +end end @testset "LSTMCell" begin lstmcell = LSTMCell(3 => 5) @@ -42,9 +38,9 @@ end (h, c), st_ = Lux.apply(lstmcell, x, ps, st) @test_call lstmcell(x, ps, st) - @test_opt target_modules = (Lux,) lstmcell(x, ps, st) + @test_opt target_modules=(Lux,) lstmcell(x, ps, st) @test_call lstmcell((x, h, c), ps, st_) - @test_opt target_modules = (Lux,) lstmcell((x, h, c), ps, st_) + @test_opt target_modules=(Lux,) lstmcell((x, h, c), ps, st_) function loss_loop_lstmcell(p) (h, c), st_ = lstmcell(x, p, st) @@ -54,7 +50,7 @@ end return sum(abs2, h) end - test_gradient_correctness_fdm(loss_loop_lstmcell, ps, atol=1e-3, rtol=1e-3) + test_gradient_correctness_fdm(loss_loop_lstmcell, ps, atol = 1e-3, rtol = 1e-3) end @testset "GRUCell" begin @@ -65,9 +61,9 @@ end h, st_ = Lux.apply(grucell, x, ps, st) @test_call grucell(x, ps, st) - @test_opt target_modules = (Lux,) grucell(x, ps, st) + @test_opt target_modules=(Lux,) grucell(x, ps, st) @test_call grucell((x, h), ps, st_) - @test_opt target_modules = (Lux,) grucell((x, h), ps, st_) + @test_opt target_modules=(Lux,) grucell((x, h), ps, st_) function loss_loop_grucell(p) h, st_ = grucell(x, p, st) @@ -77,15 +73,15 @@ end return sum(abs2, h) end - test_gradient_correctness_fdm(loss_loop_grucell, ps, atol=1e-3, rtol=1e-3) + test_gradient_correctness_fdm(loss_loop_grucell, ps, atol = 1e-3, rtol = 1e-3) end @testset "multigate" begin x = rand(6, 5) res, (dx,) = Zygote.withgradient(x) do x - x1, _, x3 = Lux.multigate(x, Val(3)) - sum(x1) + sum(x3 .* 2) + x1, _, x3 = Lux.multigate(x, Val(3)) + sum(x1) + sum(x3 .* 2) end @test res == sum(x[1:2, :]) + 2sum(x[5:6, :]) @test dx == [ones(2, 5); zeros(2, 5); fill(2, 2, 5)] -end \ No newline at end of file +end diff --git a/test/models/convnets.jl b/test/models/convnets.jl index 5ba5eef8e..55d37dd46 100644 --- a/test/models/convnets.jl +++ b/test/models/convnets.jl @@ -10,38 +10,34 @@ end GC.gc(true) -@testset "VGG" begin - @testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], bn in [true, false] - m = VGG(sz, batchnorm = bn) - m2 = Lux.transform(m.layers) +@testset "VGG" begin @testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], + bn in [true, false] - @test size(run_model(m2, rand(Float32, 224, 224, 3, 1))) == (1000, 1) + m = VGG(sz, batchnorm = bn) + m2 = Lux.transform(m.layers) - GC.gc(true) - end -end + @test size(run_model(m2, rand(Float32, 224, 224, 3, 1))) == (1000, 1) -@testset "ResNet" begin - @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152] - m = ResNet(sz) - m2 = Lux.transform(m.layers) + GC.gc(true) +end end - @test size(run_model(m2, rand(Float32, 256, 256, 3, 1))) == (1000, 1) +@testset "ResNet" begin @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152] + m = ResNet(sz) + m2 = Lux.transform(m.layers) 
- GC.gc(true) - end -end + @test size(run_model(m2, rand(Float32, 256, 256, 3, 1))) == (1000, 1) -@testset "ResNeXt" begin - @testset for depth in [50, 101, 152] - m = ResNeXt(depth) - m2 = Lux.transform(m.layers) + GC.gc(true) +end end - @test size(run_model(m2, rand(Float32, 224, 224, 3, 1))) == (1000, 1) +@testset "ResNeXt" begin @testset for depth in [50, 101, 152] + m = ResNeXt(depth) + m2 = Lux.transform(m.layers) - GC.gc(true) - end -end + @test size(run_model(m2, rand(Float32, 224, 224, 3, 1))) == (1000, 1) + + GC.gc(true) +end end @testset "GoogLeNet" begin m = GoogLeNet() @@ -70,24 +66,22 @@ end GC.gc(true) end -@testset "DenseNet" begin - @testset for sz in [121, 161, 169, 201] - m = DenseNet(sz) - m2 = Lux.transform(m.layers) - - @test size(run_model(m2, rand(Float32, 224, 224, 3, 1))) == (1000, 1) - - GC.gc(true) - end -end +@testset "DenseNet" begin @testset for sz in [121, 161, 169, 201] + m = DenseNet(sz) + m2 = Lux.transform(m.layers) + + @test size(run_model(m2, rand(Float32, 224, 224, 3, 1))) == (1000, 1) + + GC.gc(true) +end end -@testset "MobileNet" verbose = true begin +@testset "MobileNet" verbose=true begin @testset "MobileNetv1" begin m = MobileNetv1() m2 = Lux.transform(m.layers) - + @test size(run_model(m2, rand(Float32, 224, 224, 3, 1))) == (1000, 1) - + GC.gc(true) end @@ -96,24 +90,22 @@ end @testset "MobileNetv2" begin m = MobileNetv2() m2 = Lux.transform(m.layers) - + @test size(run_model(m2, rand(Float32, 224, 224, 3, 1))) == (1000, 1) - + GC.gc(true) end GC.gc() - @testset "MobileNetv3" verbose = true begin - @testset for mode in [:small, :large] - m = MobileNetv3(mode) - m2 = Lux.transform(m.layers) - - @test size(run_model(m2, rand(Float32, 224, 224, 3, 1))) == (1000, 1) - - GC.gc(true) - end - end + @testset "MobileNetv3" verbose=true begin @testset for mode in [:small, :large] + m = MobileNetv3(mode) + m2 = Lux.transform(m.layers) + + @test size(run_model(m2, rand(Float32, 224, 224, 3, 1))) == (1000, 1) + + GC.gc(true) + end end end GC.gc() @@ -122,13 +114,11 @@ GC.gc() GC.gc() -@testset "ConvMixer" verbose = true begin - @testset for mode in [:base, :large, :small] - m = ConvMixer(mode) - m2 = Lux.transform(m.layers) - - @test size(run_model(m2, rand(Float32, 224, 224, 3, 1))) == (1000, 1) - - GC.gc(true) - end -end \ No newline at end of file +@testset "ConvMixer" verbose=true begin @testset for mode in [:base, :large, :small] + m = ConvMixer(mode) + m2 = Lux.transform(m.layers) + + @test size(run_model(m2, rand(Float32, 224, 224, 3, 1))) == (1000, 1) + + GC.gc(true) +end end diff --git a/test/runtests.jl b/test/runtests.jl index 1bcd2652c..804181477 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,4 @@ end @time @safetestset "Functional Operations" begin include("functional.jl") end -@testset "Metalhead Models" begin - @time @safetestset "ConvNets -- ImageNet" begin include("models/convnets.jl") end -end +@testset "Metalhead Models" begin @time @safetestset "ConvNets -- ImageNet" begin include("models/convnets.jl") end end diff --git a/test/utils.jl b/test/utils.jl index aa6a03d5a..c19ab5f04 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -3,15 +3,16 @@ using Lux using Random using Zygote -function Base.isapprox(nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; kwargs...) where {fields} +function Base.isapprox(nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; + kwargs...) where {fields} checkapprox(xy) = isapprox(xy[1], xy[2]; kwargs...) 
- checkapprox(t::Tuple{Nothing,Nothing}) = true + checkapprox(t::Tuple{Nothing, Nothing}) = true return all(checkapprox, zip(values(nt1), values(nt2))) end -function Base.isapprox(t1::NTuple{N,T}, t2::NTuple{N,T}; kwargs...) where {N,T} +function Base.isapprox(t1::NTuple{N, T}, t2::NTuple{N, T}; kwargs...) where {N, T} checkapprox(xy) = isapprox(xy[1], xy[2]; kwargs...) - checkapprox(t::Tuple{Nothing,Nothing}) = true + checkapprox(t::Tuple{Nothing, Nothing}) = true return all(checkapprox, zip(t1, t2)) end @@ -32,10 +33,10 @@ function run_fwd_and_bwd(model, input, ps, st) return true end -function run_model(m::Lux.AbstractExplicitLayer, x, mode=:test) +function run_model(m::Lux.AbstractExplicitLayer, x, mode = :test) ps, st = Lux.setup(Random.default_rng(), m) if mode == :test st = Lux.testmode(st) end return Lux.apply(m, x, ps, st)[1] -end \ No newline at end of file +end From e61a15cb2ddc0227aa692337ba96eab2c1f3e872 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 24 May 2022 00:32:40 -0400 Subject: [PATCH 2/6] dont use julia 1.3 --- .github/workflows/FormatCheck.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml index c5aa42e41..7b7e4866c 100644 --- a/.github/workflows/FormatCheck.yml +++ b/.github/workflows/FormatCheck.yml @@ -13,7 +13,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - julia-version: [1.3.0] + julia-version: [1.7] julia-arch: [x86] os: [ubuntu-latest] steps: From ac04b1f90caf61121e11817995696a1505de2b6a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 24 May 2022 14:16:34 -0400 Subject: [PATCH 3/6] Change formatting style --- .JuliaFormatter.toml | 2 +- .github/workflows/FormatCheck.yml | 5 +- .github/workflows/FormatPR.yml | 2 +- docs/make.jl | 84 ++++++++++++++++--------------- examples/BayesianNN/main.jl | 14 +++--- examples/ImageNet/main.jl | 44 ++++++++-------- examples/NeuralODE/main.jl | 28 +++++------ examples/SimpleRNN/main.jl | 14 +++--- src/adapt.jl | 2 +- src/core.jl | 10 ++-- src/layers/basic.jl | 30 +++++------ src/layers/conv.jl | 50 +++++++++--------- src/layers/display.jl | 22 ++++---- src/layers/dropout.jl | 14 +++--- src/layers/normalize.jl | 64 +++++++++++------------ src/layers/recurrent.jl | 56 ++++++++++----------- src/nnlib.jl | 48 +++++++++--------- src/transform.jl | 18 +++---- src/utils.jl | 12 ++--- test/layers/basic.jl | 62 +++++++++++------------ test/layers/conv.jl | 48 +++++++++--------- test/layers/dropout.jl | 16 +++--- test/layers/normalize.jl | 28 +++++------ test/layers/recurrent.jl | 12 ++--- test/models/convnets.jl | 2 +- test/utils.jl | 2 +- 26 files changed, 347 insertions(+), 342 deletions(-) diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 93a9e7665..e56de732a 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,2 +1,2 @@ style = "sciml" -whitespace_in_kwargs = true +whitespace_in_kwargs = false diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml index 7b7e4866c..cc56827c4 100644 --- a/.github/workflows/FormatCheck.yml +++ b/.github/workflows/FormatCheck.yml @@ -26,8 +26,11 @@ jobs: # This will use the latest version by default but you can set the version like so: # # julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))' + # + # FIXME: Before merging change to default release + # julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' + julia -e 'using Pkg; 
Pkg.add(PackageSpec(url="https://github.com/YingboMa/JuliaFormatter.jl.git", rev="myb/scimlstyle"))' julia -e 'using JuliaFormatter; format(".", verbose=true)' - name: Format check run: | diff --git a/.github/workflows/FormatPR.yml b/.github/workflows/FormatPR.yml index 3a4c959aa..f2113e8e5 100644 --- a/.github/workflows/FormatPR.yml +++ b/.github/workflows/FormatPR.yml @@ -9,7 +9,7 @@ jobs: - uses: actions/checkout@v2 - name: Install JuliaFormatter and format run: | - julia -e 'import Pkg; Pkg.add("JuliaFormatter")' + julia -e 'using Pkg; Pkg.add(PackageSpec(url="https://github.com/YingboMa/JuliaFormatter.jl.git", rev="myb/scimlstyle"))' julia -e 'using JuliaFormatter; format(".")' # https://github.com/marketplace/actions/create-pull-request # https://github.com/peter-evans/create-pull-request#reference-example diff --git a/docs/make.jl b/docs/make.jl index 4fd8b12bd..dcc689636 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,7 +2,7 @@ using Documenter, Lux, Literate, Pkg # Precompile example dependencies Pkg.activate(joinpath(@__DIR__, "..", "examples")) -Pkg.develop(PackageSpec(; path = joinpath(@__DIR__, ".."))) +Pkg.develop(PackageSpec(; path=joinpath(@__DIR__, ".."))) Pkg.instantiate() Pkg.precompile() Pkg.activate(@__DIR__) @@ -13,8 +13,8 @@ if haskey(ENV, "GITHUB_ACTIONS") end deployconfig = Documenter.auto_detect_deploy_system() -Documenter.post_status(deployconfig; type = "pending", - repo = "github.com/avik-pal/Lux.jl.git") +Documenter.post_status(deployconfig; type="pending", + repo="github.com/avik-pal/Lux.jl.git") # Tutorials get_example_path(p) = joinpath(@__DIR__, "..", "examples", p) @@ -28,13 +28,13 @@ ADVANCED_TUTORIALS = [] ADVANCED_TUTORIAL_NAMES = [] MAPPING = Dict("beginner" => [], "intermediate" => [], "advanced" => []) -for (d, names, paths) in - (("beginner", BEGINNER_TUTORIAL_NAMES, BEGINNER_TUTORIALS), - ("intermediate", INTERMEDIATE_TUTORIAL_NAMES, INTERMEDIATE_TUTORIALS), - ("advanced", ADVANCED_TUTORIAL_NAMES, ADVANCED_TUTORIALS)) +for (d, names, paths) in (("beginner", BEGINNER_TUTORIAL_NAMES, BEGINNER_TUTORIALS), + ("intermediate", INTERMEDIATE_TUTORIAL_NAMES, + INTERMEDIATE_TUTORIALS), + ("advanced", ADVANCED_TUTORIAL_NAMES, ADVANCED_TUTORIALS)) for (n, p) in zip(names, paths) Literate.markdown(get_example_path(p), joinpath(OUTPUT, d, dirname(p)); - documenter = true) + documenter=true) push!(MAPPING[d], n => joinpath("examples/generated", d, dirname(p), splitext(basename(p))[1] * ".md")) @@ -44,39 +44,41 @@ end display(MAPPING) makedocs(; - sitename="Lux", - authors="Avik Pal et al.", - clean=true, - doctest=false, - modules=[Lux], - format=Documenter.HTML(; - prettyurls=get(ENV, "CI", nothing) == "true", - assets=["assets/custom.css"], - analytics = "G-Q8GYTEVTZ2" - ), - pages=[ - "Lux: Explicitly Parameterized Neural Networks" => "index.md", - "Introduction" => ["All about Lux" => "introduction/overview.md", "Ecosystem" => "introduction/ecosystem.md"], - "Examples" => [ - "Beginner" => MAPPING["beginner"], - "Intermediate" => MAPPING["intermediate"], - "Advanced" => MAPPING["advanced"], - "Additional Examples" => "examples.md", - ], - "API" => [ - "Layers" => "api/layers.md", - "Functional" => "api/functional.md", - "Core" => "api/core.md", - "Utilities" => "api/utilities.md", - ], - "Design Docs" => [ - "Documentation" => "design/documentation.md", - "Recurrent Neural Networks" => "design/recurrent.md", - ] - ], -) + sitename="Lux", + authors="Avik Pal et al.", + clean=true, + doctest=false, + 
modules=[Lux], + format=Documenter.HTML(; + prettyurls=get(ENV, "CI", nothing) == "true", + assets=["assets/custom.css"] + # analytics = "G-Q8GYTEVTZ2" + ), + pages=[ + "Lux: Explicitly Parameterized Neural Networks" => "index.md", + "Introduction" => [ + "All about Lux" => "introduction/overview.md", + "Ecosystem" => "introduction/ecosystem.md", + ], + "Examples" => [ + "Beginner" => MAPPING["beginner"], + "Intermediate" => MAPPING["intermediate"], + "Advanced" => MAPPING["advanced"], + "Additional Examples" => "examples.md", + ], + "API" => [ + "Layers" => "api/layers.md", + "Functional" => "api/functional.md", + "Core" => "api/core.md", + "Utilities" => "api/utilities.md", + ], + "Design Docs" => [ + "Documentation" => "design/documentation.md", + "Recurrent Neural Networks" => "design/recurrent.md", + ], + ]) -deploydocs(; repo = "github.com/avik-pal/Lux.jl.git", push_preview = true, - devbranch = "main") +deploydocs(; repo="github.com/avik-pal/Lux.jl.git", push_preview=true, + devbranch="main") Pkg.activate(@__DIR__) diff --git a/examples/BayesianNN/main.jl b/examples/BayesianNN/main.jl index 8f7d098aa..50d444432 100644 --- a/examples/BayesianNN/main.jl +++ b/examples/BayesianNN/main.jl @@ -52,8 +52,8 @@ function plot_data() x2 = first.(xt0s) y2 = last.(xt0s) - plt = Plots.scatter(x1, y1; color = "red", clim = (0, 1)) - Plots.scatter!(plt, x2, y2; color = "blue", clim = (0, 1)) + plt = Plots.scatter(x1, y1; color="red", clim=(0, 1)) + Plots.scatter!(plt, x2, y2; color="blue", clim=(0, 1)) return plt end @@ -135,8 +135,8 @@ _, i = findmax(ch[:lp]) i = i.I[1] ## Plot the posterior distribution with a contour plot -x1_range = collect(range(-6; stop = 6, length = 25)) -x2_range = collect(range(-6; stop = 6, length = 25)) +x1_range = collect(range(-6; stop=6, length=25)) +x2_range = collect(range(-6; stop=6, length=25)) Z = [nn_forward([x1, x2], theta[i, :])[1] for x1 in x1_range, x2 in x2_range] contour!(x1_range, x2_range, Z) @@ -157,8 +157,8 @@ end plot_data() n_end = 1500 -x1_range = collect(range(-6; stop = 6, length = 25)) -x2_range = collect(range(-6; stop = 6, length = 25)) +x1_range = collect(range(-6; stop=6, length=25)) +x2_range = collect(range(-6; stop=6, length=25)) Z = [nn_predict([x1, x2], theta, n_end)[1] for x1 in x1_range, x2 in x2_range] contour!(x1_range, x2_range, Z) @@ -170,5 +170,5 @@ n_end = 1000 anim = @gif for i in 1:n_end plot_data() Z = [nn_forward([x1, x2], theta[i, :])[1] for x1 in x1_range, x2 in x2_range] - contour!(x1_range, x2_range, Z; title = "Iteration $i", clim = (0, 1)) + contour!(x1_range, x2_range, Z; title="Iteration $i", clim=(0, 1)) end every 5 diff --git a/examples/ImageNet/main.jl b/examples/ImageNet/main.jl index e11fdc50d..58f60754d 100644 --- a/examples/ImageNet/main.jl +++ b/examples/ImageNet/main.jl @@ -26,17 +26,17 @@ import DataLoaders: LearnBase # Extending Datasets import MLUtils # Distributed Training -FluxMPI.Init(; verbose = true) +FluxMPI.Init(; verbose=true) CUDA.allowscalar(false) # unsafe_free OneHotArrays CUDA.unsafe_free!(x::OneHotArray) = CUDA.unsafe_free!(x.indices) # Image Classification Models -VGG11_BN(args...; kwargs...) = VGG11(args...; batchnorm = true, kwargs...) -VGG13_BN(args...; kwargs...) = VGG13(args...; batchnorm = true, kwargs...) -VGG16_BN(args...; kwargs...) = VGG16(args...; batchnorm = true, kwargs...) -VGG19_BN(args...; kwargs...) = VGG19(args...; batchnorm = true, kwargs...) +VGG11_BN(args...; kwargs...) = VGG11(args...; batchnorm=true, kwargs...) +VGG13_BN(args...; kwargs...) 
= VGG13(args...; batchnorm=true, kwargs...) +VGG16_BN(args...; kwargs...) = VGG16(args...; batchnorm=true, kwargs...) +VGG19_BN(args...; kwargs...) = VGG19(args...; batchnorm=true, kwargs...) MobileNetv3_small(args...; kwargs...) = MobileNetv3(:small, args...; kwargs...) MobileNetv3_large(args...; kwargs...) = MobileNetv3(:large, args...; kwargs...) ResNeXt50(args...; kwargs...) = ResNeXt(50, args...; kwargs...) @@ -75,7 +75,7 @@ AVAILABLE_IMAGENET_MODELS = [ IMAGENET_MODELS_DICT = Dict(string(model) => model for model in AVAILABLE_IMAGENET_MODELS) -function get_model(model_name::String, models_dict::Dict, rng, args...; warmup = true, +function get_model(model_name::String, models_dict::Dict, rng, args...; warmup=true, kwargs...) model = Lux.transform(models_dict[model_name](args...; kwargs...).layers) ps, st = Lux.setup(rng, model) .|> gpu @@ -93,8 +93,8 @@ function get_model(model_name::String, models_dict::Dict, rng, args...; warmup = end if is_distributed() - ps = FluxMPI.synchronize!(ps; root_rank = 0) - st = FluxMPI.synchronize!(st; root_rank = 0) + ps = FluxMPI.synchronize!(ps; root_rank=0) + st = FluxMPI.synchronize!(st; root_rank=0) should_log() && println("$(now()) ==> models synced across all ranks") end @@ -160,7 +160,7 @@ function parse_commandline_arguments() end # Loss Function -logitcrossentropyloss(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims = 1); dims = 1)) +logitcrossentropyloss(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1)) function logitcrossentropyloss(x, y, model, ps, st) ŷ, st_ = model(x, ps, st) @@ -181,10 +181,10 @@ end update_lr(st_opt::NamedTuple, eta) = fmap(l -> update_lr(l, eta), st_opt) # Accuracy -function accuracy(ŷ, y, topk = (1,)) +function accuracy(ŷ, y, topk=(1,)) maxk = maximum(topk) - pred_labels = partialsortperm.(eachcol(ŷ), (1:maxk,), rev = true) + pred_labels = partialsortperm.(eachcol(ŷ), (1:maxk,), rev=true) true_labels = onecold(y) accuracies = Vector{Float32}(undef, length(topk)) @@ -202,11 +202,11 @@ is_distributed() = FluxMPI.Initialized() && total_workers() > 1 should_log() = !FluxMPI.Initialized() || local_rank() == 0 # Checkpointing -function save_checkpoint(state, is_best, filename = "checkpoint.pth.tar") +function save_checkpoint(state, is_best, filename="checkpoint.pth.tar") if should_log() serialize(filename, state) if is_best - cp(filename, "model_best.pth.tar"; force = true) + cp(filename, "model_best.pth.tar"; force=true) end end end @@ -260,12 +260,12 @@ function ImageDataset(folder::String, augmentation_pipeline, normalization_param "n02105855_2933.JPEG", ] remove_files = joinpath.((folder,), - joinpath.(first.(rsplit.(remove_files, "_", limit = 2)), + joinpath.(first.(rsplit.(remove_files, "_", limit=2)), remove_files)) image_files = [setdiff(Set(image_files), Set(remove_files))...] - labels = [mapping[x] for x in map(x -> x[2], rsplit.(image_files, "/", limit = 3))] + labels = [mapping[x] for x in map(x -> x[2], rsplit.(image_files, "/", limit=3))] else vallist = hcat(split.(readlines(joinpath(@__DIR__, "val_list.txt")))...) 
labels = parse.(Int, vallist[2, :]) .+ 1 @@ -315,7 +315,7 @@ end function AverageMeter(name::String, fmt::String) fmtstr = FormatExpr("$name {1:$fmt} ({2:$fmt})") - return AverageMeter(; fmtstr = fmtstr) + return AverageMeter(; fmtstr=fmtstr) end function update!(meter::AverageMeter, val, n::Int) @@ -333,7 +333,7 @@ struct ProgressMeter{N} meters::NTuple{N, AverageMeter} end -function ProgressMeter(num_batches::Int, meters::NTuple{N}, prefix::String = "") where {N} +function ProgressMeter(num_batches::Int, meters::NTuple{N}, prefix::String="") where {N} fmt = "%" * string(length(string(num_batches))) * "d" prefix = prefix != "" ? endswith(prefix, " ") ? prefix : prefix * " " : "" batch_fmtstr = generate_formatter("$prefix[$fmt/" * sprintf1(fmt, num_batches) * "]") @@ -464,11 +464,11 @@ function main(args) println("$(now()) => creating model `$(args["arch"])`") end end - model, ps, st = get_model(args["arch"], IMAGENET_MODELS_DICT, rng; warmup = true, - pretrain = args["pretrained"]) + model, ps, st = get_model(args["arch"], IMAGENET_MODELS_DICT, rng; warmup=true, + pretrain=args["pretrained"]) - normalization_parameters = (mean = reshape([0.485f0, 0.456f0, 0.406f0], 1, 1, 3), - std = reshape([0.229f0, 0.224f0, 0.225f0], 1, 1, 3)) + normalization_parameters = (mean=reshape([0.485f0, 0.456f0, 0.406f0], 1, 1, 3), + std=reshape([0.229f0, 0.224f0, 0.225f0], 1, 1, 3)) train_data_augmentation = Resize(256, 256) |> FlipX(0.5) |> RCropSize(224, 224) val_data_augmentation = Resize(256, 256) |> CropSize(224, 224) train_dataset = ImageDataset(joinpath(args["data"], "train"), @@ -496,7 +496,7 @@ function main(args) optimiser_state = FluxMPI.synchronize!(optimiser_state) should_log() && println("$(now()) ==> synced optimiser state across all ranks") end - scheduler = Step(λ = args["learning-rate"], γ = 0.1f0, step_sizes = 30) + scheduler = Step(λ=args["learning-rate"], γ=0.1f0, step_sizes=30) if args["resume"] != "" if isfile(args["resume"]) diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index bb87db840..062ad87d4 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -26,13 +26,13 @@ function loadmnist(batchsize, train_split) ## Process images into (H,W,C,BS) batches x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) y_data = onehot(labels_raw) - (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at = train_split) + (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split) return ( ## Use DataLoader to automatically minibatch and shuffle the data - DataLoader(collect.((x_train, y_train)); batchsize = batchsize, shuffle = true), + DataLoader(collect.((x_train, y_train)); batchsize=batchsize, shuffle=true), ## Don't shuffle the test data - DataLoader(collect.((x_test, y_test)); batchsize = batchsize, shuffle = false)) + DataLoader(collect.((x_test, y_test)); batchsize=batchsize, shuffle=false)) end # ## Define the Neural ODE Layer @@ -49,9 +49,9 @@ struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, Se, T, K} <: end function NeuralODE(model::Lux.AbstractExplicitLayer; - solver = Tsit5(), - sensealg = InterpolatingAdjoint(; autojacvec = ZygoteVJP()), - tspan = (0.0f0, 1.0f0), + solver=Tsit5(), + sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()), + tspan=(0.0f0, 1.0f0), kwargs...) 
return NeuralODE(model, solver, sensealg, tspan, kwargs) end @@ -62,13 +62,13 @@ function (n::NeuralODE)(x, ps, st) return u_ end prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps) - return solve(prob, n.solver; sensealg = n.sensealg, n.kwargs...), st + return solve(prob, n.solver; sensealg=n.sensealg, n.kwargs...), st end function diffeqsol_to_array(x::ODESolution{T, N, <:AbstractVector{<:CuArray}}) where {T, N} - dropdims(gpu(x); dims = 3) + dropdims(gpu(x); dims=3) end -diffeqsol_to_array(x::ODESolution) = dropdims(Array(x); dims = 3) +diffeqsol_to_array(x::ODESolution) = dropdims(Array(x); dims=3) # ## Create and Initialize the Neural ODE Layer function create_model() @@ -77,10 +77,10 @@ function create_model() Dense(784, 20, tanh), NeuralODE(Chain(Dense(20, 10, tanh), Dense(10, 10, tanh), Dense(10, 20, tanh)); - save_everystep = false, - reltol = 1.0f-3, - abstol = 1.0f-3, - save_start = false), + save_everystep=false, + reltol=1.0f-3, + abstol=1.0f-3, + save_start=false), diffeqsol_to_array, Dense(20, 10)) @@ -97,7 +97,7 @@ end # ## Define Utility Functions get_class(x) = argmax.(eachcol(x)) -logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ); dims = 1)) +logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ); dims=1)) function loss(x, y, model, ps, st) ŷ, st = model(x, ps, st) diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index d10a211c9..a1e817fba 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -16,7 +16,7 @@ using MLUtils, Optimisers, Zygote, NNlib, Random, Statistics # We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create a `MLUtils.DataLoader`. Our dataloader will give us sequences of size 2 × seq_len × batch_size and we need to predict a binary value whether the sequence is clockwise or anticlockwise -function get_dataloaders(; dataset_size = 1000, sequence_length = 50) +function get_dataloaders(; dataset_size=1000, sequence_length=50) ## Create the spirals data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size] ## Get the labels @@ -25,16 +25,16 @@ function get_dataloaders(; dataset_size = 1000, sequence_length = 50) for d in data[1:(dataset_size ÷ 2)]] anticlockwise_spirals = [reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for d in data[((dataset_size ÷ 2) + 1):end]] - x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims = 3)) + x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3)) ## Split the dataset - (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at = 0.8, - shuffle = true) + (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, + shuffle=true) ## Create DataLoaders return ( ## Use DataLoader to automatically minibatch and shuffle the data - DataLoader(collect.((x_train, y_train)); batchsize = 128, shuffle = true), + DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true), ## Don't shuffle the validation data - DataLoader(collect.((x_val, y_val)); batchsize = 128, shuffle = false)) + DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false)) end # ## Creating a Classifier @@ -72,7 +72,7 @@ function (s::SpiralClassifier)(x::AbstractArray{T, 3}, ps::NamedTuple, ## After running through the sequence we will pass the output through the classifier y, st_classifier = s.classifier(h, ps.classifier, st.classifier) ## Finally remember to create the updated state - st = merge(st, (classifier = 
st_classifier, lstm_cell = st_lstm)) + st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm)) return vec(y), st end diff --git a/src/adapt.jl b/src/adapt.jl index 45b7d3312..e6ded3a9e 100644 --- a/src/adapt.jl +++ b/src/adapt.jl @@ -48,7 +48,7 @@ Transfer `x` to GPU """ function gpu(x) check_use_cuda() - return use_cuda[] ? fmap(x -> adapt(LuxCUDAAdaptor(), x), x; exclude = _isleaf) : x + return use_cuda[] ? fmap(x -> adapt(LuxCUDAAdaptor(), x), x; exclude=_isleaf) : x end function check_use_cuda() diff --git a/src/core.jl b/src/core.jl index e3d6f3678..32b772733 100644 --- a/src/core.jl +++ b/src/core.jl @@ -71,7 +71,7 @@ function apply(model::AbstractExplicitLayer, x, ps::Union{ComponentArray, NamedT end function Base.show(io::IO, x::AbstractExplicitLayer) - __t = rsplit(string(get_typename(x)), "."; limit = 2) + __t = rsplit(string(get_typename(x)), "."; limit=2) T = length(__t) == 2 ? __t[2] : __t[1] print(io, "$T()") end @@ -105,14 +105,14 @@ end Make all occurances of `training` in state `st` `!mode` """ -testmode(st::NamedTuple, mode::Bool = true) = update_state(st, :training, Val(!mode)) +testmode(st::NamedTuple, mode::Bool=true) = update_state(st, :training, Val(!mode)) """ trainmode(x::Any, mode::Bool=true) Make all occurances of `training` in state `st` `mode` """ -trainmode(x::Any, mode::Bool = true) = testmode(x, !mode) +trainmode(x::Any, mode::Bool=true) = testmode(x, !mode) """ update_state(st::NamedTuple, key::Symbol, value; layer_check=_default_layer_check(key)) @@ -120,11 +120,11 @@ trainmode(x::Any, mode::Bool = true) = testmode(x, !mode) Recursively update all occurances of the `key` in the state `st` with the `value`. """ function update_state(st::NamedTuple, key::Symbol, value; - layer_check = _default_layer_check(key)) + layer_check=_default_layer_check(key)) function _update_state(st, key::Symbol, value) return Setfield.set(st, Setfield.PropertyLens{key}(), value) end - return fmap(_st -> _update_state(_st, key, value), st; exclude = layer_check) + return fmap(_st -> _update_state(_st, key, value), st; exclude=layer_check) end function _default_layer_check(key) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 13ab4423a..a06e4f525 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -499,7 +499,7 @@ c = Chain( """ struct Chain{T} <: AbstractExplicitContainerLayer{(:layers,)} layers::T - function Chain(xs...; disable_optimizations::Bool = false) + function Chain(xs...; disable_optimizations::Bool=false) xs = disable_optimizations ? 
xs : flatten_model(xs) length(xs) == 0 && return NoOpLayer() length(xs) == 1 && return first(xs) @@ -507,7 +507,7 @@ struct Chain{T} <: AbstractExplicitContainerLayer{(:layers,)} layers = NamedTuple{names}(xs) return new{typeof(layers)}(layers) end - function Chain(xs::AbstractVector; disable_optimizations::Bool = false) + function Chain(xs::AbstractVector; disable_optimizations::Bool=false) Chain(xs...; disable_optimizations) end end @@ -618,14 +618,14 @@ function Base.show(io::IO, d::Dense{bias}) where {bias} return print(io, ")") end -function Dense(mapping::Pair{<:Int, <:Int}, activation = identity; - init_weight = glorot_uniform, init_bias = zeros32, bias::Bool = true) - return Dense(first(mapping), last(mapping), activation; init_weight = init_weight, - init_bias = init_bias, bias = bias) +function Dense(mapping::Pair{<:Int, <:Int}, activation=identity; + init_weight=glorot_uniform, init_bias=zeros32, bias::Bool=true) + return Dense(first(mapping), last(mapping), activation; init_weight=init_weight, + init_bias=init_bias, bias=bias) end -function Dense(in_dims::Int, out_dims::Int, activation = identity; - init_weight = glorot_uniform, init_bias = zeros32, bias::Bool = true) +function Dense(in_dims::Int, out_dims::Int, activation=identity; + init_weight=glorot_uniform, init_bias=zeros32, bias::Bool=true) activation = NNlib.fast_act(activation) return Dense{bias, typeof(activation), typeof(init_weight), typeof(init_bias)}(activation, in_dims, @@ -636,10 +636,10 @@ end function initialparameters(rng::AbstractRNG, d::Dense{bias}) where {bias} if bias - return (weight = d.init_weight(rng, d.out_dims, d.in_dims), - bias = d.init_bias(rng, d.out_dims, 1)) + return (weight=d.init_weight(rng, d.out_dims, d.in_dims), + bias=d.init_bias(rng, d.out_dims, 1)) else - return (weight = d.init_weight(rng, d.out_dims, d.in_dims),) + return (weight=d.init_weight(rng, d.out_dims, d.in_dims),) end end @@ -724,18 +724,18 @@ function Base.show(io::IO, d::Scale) return print(io, ")") end -function Scale(dims, activation = identity; init_weight = glorot_uniform, - init_bias = zeros32, bias::Bool = true) +function Scale(dims, activation=identity; init_weight=glorot_uniform, + init_bias=zeros32, bias::Bool=true) activation = NNlib.fast_act(activation) return Scale{bias, typeof(activation), typeof(dims), typeof(init_weight), typeof(init_bias)}(activation, dims, init_weight, init_bias) end function initialparameters(rng::AbstractRNG, d::Scale{true}) - return (weight = d.init_weight(rng, d.dims), bias = d.init_bias(rng, d.dims)) + return (weight=d.init_weight(rng, d.dims), bias=d.init_bias(rng, d.dims)) end function initialparameters(rng::AbstractRNG, d::Scale{false}) - (weight = d.init_weight(rng, d.dims),) + (weight=d.init_weight(rng, d.dims),) end parameterlength(d::Scale{bias}) where {bias} = (1 + bias) * d.dims diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 0e50188cf..254765931 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -56,13 +56,13 @@ end function Conv(k::NTuple{N, Integer}, ch::Pair{<:Integer, <:Integer}, - activation = identity; - init_weight = glorot_uniform, - stride = 1, - pad = 0, - dilation = 1, - groups = 1, - bias = true) where {N} + activation=identity; + init_weight=glorot_uniform, + stride=1, + pad=0, + dilation=1, + groups=1, + bias=true) where {N} stride = expand(Val(N), stride) dilation = expand(Val(N), dilation) pad = calc_padding(Conv, pad, k, dilation, stride) @@ -77,12 +77,12 @@ function Conv(k::NTuple{N, Integer}, end function 
initialparameters(rng::AbstractRNG, c::Conv{N, bias}) where {N, bias} - weight = convfilter(rng, c.kernel_size, c.in_chs => c.out_chs; init = c.init_weight, - groups = c.groups) + weight = convfilter(rng, c.kernel_size, c.in_chs => c.out_chs; init=c.init_weight, + groups=c.groups) return bias ? - (weight = weight, - bias = zeros(eltype(weight), ntuple(_ -> 1, N)..., c.out_chs, 1)) : - (weight = weight,) + (weight=weight, + bias=zeros(eltype(weight), ntuple(_ -> 1, N)..., c.out_chs, 1)) : + (weight=weight,) end function parameterlength(c::Conv{N, bias}) where {N, bias} @@ -92,15 +92,15 @@ end @inline function (c::Conv{N, false})(x::AbstractArray, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) where {N} - cdims = DenseConvDims(x, ps.weight; stride = c.stride, padding = c.pad, - dilation = c.dilation, groups = c.groups) + cdims = DenseConvDims(x, ps.weight; stride=c.stride, padding=c.pad, + dilation=c.dilation, groups=c.groups) return applyactivation(c.activation, conv_wrapper(x, ps.weight, cdims)), st end @inline function (c::Conv{N, true})(x::AbstractArray, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) where {N} - cdims = DenseConvDims(x, ps.weight; stride = c.stride, padding = c.pad, - dilation = c.dilation, groups = c.groups) + cdims = DenseConvDims(x, ps.weight; stride=c.stride, padding=c.pad, + dilation=c.dilation, groups=c.groups) return applyactivation(c.activation, elementwise_add(conv_wrapper(x, ps.weight, cdims), ps.bias)), st end @@ -157,14 +157,14 @@ struct MaxPool{N, M} <: AbstractExplicitLayer stride::NTuple{N, Int} end -function MaxPool(k::NTuple{N, Integer}; pad = 0, stride = k) where {N} +function MaxPool(k::NTuple{N, Integer}; pad=0, stride=k) where {N} stride = expand(Val(N), stride) pad = calc_padding(MaxPool, pad, k, 1, stride) return MaxPool{N, length(pad)}(k, pad, stride) end function (m::MaxPool{N, M})(x, ps, st::NamedTuple) where {N, M} - pdims = PoolDims(x, m.k; padding = m.pad, stride = m.stride) + pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride) return maxpool(x, pdims), st end @@ -210,14 +210,14 @@ struct MeanPool{N, M} <: AbstractExplicitLayer stride::NTuple{N, Int} end -function MeanPool(k::NTuple{N, Integer}; pad = 0, stride = k) where {N} +function MeanPool(k::NTuple{N, Integer}; pad=0, stride=k) where {N} stride = expand(Val(N), stride) pad = calc_padding(MeanPool, pad, k, 1, stride) return MeanPool{N, length(pad)}(k, pad, stride) end function (m::MeanPool{N, M})(x, ps, st::NamedTuple) where {N, M} - pdims = PoolDims(x, m.k; padding = m.pad, stride = m.stride) + pdims = PoolDims(x, m.k; padding=m.pad, stride=m.stride) return meanpool(x, pdims), st end @@ -272,7 +272,7 @@ struct Upsample{mode, S, T} <: AbstractExplicitLayer size::T end -function Upsample(mode::Symbol = :nearest; scale = nothing, size = nothing) +function Upsample(mode::Symbol=:nearest; scale=nothing, size=nothing) mode in [:nearest, :bilinear, :trilinear] || throw(ArgumentError("mode=:$mode is not supported.")) if !(isnothing(scale) ⊻ isnothing(size)) @@ -281,7 +281,7 @@ function Upsample(mode::Symbol = :nearest; scale = nothing, size = nothing) return Upsample{mode, typeof(scale), typeof(size)}(scale, size) end -Upsample(scale, mode::Symbol = :nearest) = Upsample(mode; scale) +Upsample(scale, mode::Symbol=:nearest) = Upsample(mode; scale) function (m::Upsample{:nearest})(x::AbstractArray, ps, st::NamedTuple) return NNlib.upsample_nearest(x, m.scale), st @@ -291,21 +291,21 @@ function (m::Upsample{:nearest, Int})(x::AbstractArray{T, N}, ps, return 
NNlib.upsample_nearest(x, ntuple(i -> m.scale, N - 2)), st end function (m::Upsample{:nearest, Nothing})(x::AbstractArray, ps, st::NamedTuple) - return NNlib.upsample_nearest(x; size = m.size), st + return NNlib.upsample_nearest(x; size=m.size), st end function (m::Upsample{:bilinear})(x::AbstractArray, ps, st::NamedTuple) return NNlib.upsample_bilinear(x, m.scale), st end function (m::Upsample{:bilinear, Nothing})(x::AbstractArray, ps, st::NamedTuple) - return NNlib.upsample_bilinear(x; size = m.size), st + return NNlib.upsample_bilinear(x; size=m.size), st end function (m::Upsample{:trilinear})(x::AbstractArray, ps, st::NamedTuple) return NNlib.upsample_trilinear(x, m.scale), st end function (m::Upsample{:trilinear, Nothing})(x::AbstractArray, ps, st::NamedTuple) - return NNlib.upsample_trilinear(x; size = m.size), st + return NNlib.upsample_trilinear(x; size=m.size), st end function Base.show(io::IO, u::Upsample{mode}) where {mode} diff --git a/src/layers/display.jl b/src/layers/display.jl index 9817407d7..641f4c829 100644 --- a/src/layers/display.jl +++ b/src/layers/display.jl @@ -8,7 +8,7 @@ function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitContainerLayer end end -function _big_show(io::IO, obj, indent::Int = 0, name = nothing) +function _big_show(io::IO, obj, indent::Int=0, name=nothing) pre, post = "(", ")" children = _get_children(obj) if obj isa Function @@ -80,19 +80,19 @@ function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitLayer) end end -function _layer_show(io::IO, layer, indent::Int = 0, name = nothing) +function _layer_show(io::IO, layer, indent::Int=0, name=nothing) _str = isnothing(name) ? "" : "$name = " - str = _str * sprint(show, layer; context = io) + str = _str * sprint(show, layer; context=io) print(io, " "^indent, str, indent == 0 ? "" : ",") paramlength = parameterlength(layer) if paramlength > 0 print(io, " "^max(2, (indent == 0 ? 20 : 39) - indent - length(str))) printstyled(io, "# ", underscorise(paramlength), " parameters"; - color = :light_black) + color=:light_black) nonparam = statelength(layer) if nonparam > 0 printstyled(io, ", plus ", underscorise(nonparam), - indent == 0 ? " non-trainable" : ""; color = :light_black) + indent == 0 ? 
" non-trainable" : ""; color=:light_black) end end return indent == 0 || println(io) @@ -104,11 +104,11 @@ function _big_finale(io::IO, m) pars = underscorise(paramlength) bytes = Base.format_bytes(Base.summarysize(m)) nonparam = underscorise(nonparamlength) - printstyled(io, " "^08, "# Total: "; color = :light_black) + printstyled(io, " "^08, "# Total: "; color=:light_black) println(io, pars, " parameters,") - printstyled(io, " "^10, "# plus "; color = :light_black) + printstyled(io, " "^10, "# plus "; color=:light_black) print(io, nonparam, " states, ") - printstyled(io, "summarysize "; color = :light_black) + printstyled(io, "summarysize "; color=:light_black) print(io, bytes, ".") return end @@ -121,11 +121,11 @@ end function _nan_show(io::IO, x) if !isempty(x) && _all(iszero, x) - printstyled(io, " (all zero)"; color = :cyan) + printstyled(io, " (all zero)"; color=:cyan) elseif _any(isnan, x) - printstyled(io, " (some NaN)"; color = :red) + printstyled(io, " (some NaN)"; color=:red) elseif _any(isinf, x) - printstyled(io, " (some Inf)"; color = :red) + printstyled(io, " (some Inf)"; color=:red) end end diff --git a/src/layers/dropout.jl b/src/layers/dropout.jl index 6b915d196..d696a8f2e 100644 --- a/src/layers/dropout.jl +++ b/src/layers/dropout.jl @@ -37,17 +37,17 @@ end function initialstates(rng::AbstractRNG, ::Dropout) # FIXME: Take PRNGs seriously randn(rng, 1) - return (rng = replicate(rng), training = Val(true)) + return (rng=replicate(rng), training=Val(true)) end -function Dropout(p; dims = :) +function Dropout(p; dims=:) @assert 0 ≤ p ≤ 1 return Dropout(p, dims) end function (d::Dropout{T})(x::AbstractArray{T}, ps, st::NamedTuple) where {T} y, _, rng = dropout(st.rng, x, d.p, d.dims, st.training) - return y, merge(st, (rng = rng,)) + return y, merge(st, (rng=rng,)) end function Base.show(io::IO, d::Dropout) @@ -97,11 +97,11 @@ end function initialstates(rng::AbstractRNG, ::VariationalHiddenDropout) # FIXME: Take PRNGs seriously randn(rng, 1) - return (rng = replicate(rng), training = Val(true), update_mask = Val(true), - mask = nothing) + return (rng=replicate(rng), training=Val(true), update_mask=Val(true), + mask=nothing) end -function VariationalHiddenDropout(p; dims = :) +function VariationalHiddenDropout(p; dims=:) @assert 0 ≤ p ≤ 1 return VariationalHiddenDropout(p, dims) end @@ -109,7 +109,7 @@ end function (d::VariationalHiddenDropout{T})(x::AbstractArray{T}, ps, st::NamedTuple) where {T} y, mask, rng, update_mask = dropout(st.rng, x, st.mask, d.p, d.dims, st.training, st.update_mask) - return y, merge(st, (mask = mask, rng = rng, update_mask = update_mask)) + return y, merge(st, (mask=mask, rng=rng, update_mask=update_mask)) end function Base.show(io::IO, d::VariationalHiddenDropout) diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index 58e5c5b09..b8a5a4db0 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -73,13 +73,13 @@ struct BatchNorm{affine, track_stats, F1, F2, F3, N} <: end function BatchNorm(chs::Int, - activation = identity; - init_bias = zeros32, - init_scale = ones32, - affine::Bool = true, - track_stats::Bool = true, - epsilon = 1.0f-5, - momentum = 0.1f0) + activation=identity; + init_bias=zeros32, + init_scale=ones32, + affine::Bool=true, + track_stats::Bool=true, + epsilon=1.0f-5, + momentum=0.1f0) activation = NNlib.fast_act(activation) return BatchNorm{affine, track_stats, typeof(activation), typeof(init_bias), typeof(init_scale), typeof(epsilon)}(activation, epsilon, momentum, @@ -87,16 +87,16 @@ function 
BatchNorm(chs::Int, end function initialparameters(rng::AbstractRNG, l::BatchNorm{affine}) where {affine} - return affine ? (scale = l.init_scale(rng, l.chs), bias = l.init_bias(rng, l.chs)) : + return affine ? (scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs)) : NamedTuple() end function initialstates(rng::AbstractRNG, l::BatchNorm{affine, track_stats}) where {affine, track_stats} return if track_stats - (running_mean = zeros32(rng, l.chs), running_var = ones32(rng, l.chs), - training = Val(true)) + (running_mean=zeros32(rng, l.chs), running_var=ones32(rng, l.chs), + training=Val(true)) else - (running_mean = nothing, running_var = nothing, training = Val(true)) + (running_mean=nothing, running_var=nothing, training=Val(true)) end end @@ -120,7 +120,7 @@ function (BN::BatchNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N BN.momentum, BN.epsilon) - st = merge(st, (running_mean = xmean, running_var = xvar)) + st = merge(st, (running_mean=xmean, running_var=xvar)) return x_normalized, st end @@ -143,9 +143,9 @@ function (BN::BatchNorm{affine, track_stats})(x::Union{CuArray{T, 2}, CuArray{T, else N = ndims(x) reduce_dims = collect([1:(N - 2); N]) - running_mean2 = mean(x; dims = reduce_dims) - running_var2 = var(x; mean = running_mean2, dims = reduce_dims, - corrected = false) + running_mean2 = mean(x; dims=reduce_dims) + running_var2 = var(x; mean=running_mean2, dims=reduce_dims, + corrected=false) end end res = applyactivation(BN.activation, @@ -155,10 +155,10 @@ function (BN::BatchNorm{affine, track_stats})(x::Union{CuArray{T, 2}, CuArray{T, running_mean2, running_var2, BN.momentum; - eps = BN.epsilon, - training = istraining(st))) + eps=BN.epsilon, + training=istraining(st))) if track_stats - st = merge(st, (running_mean = running_mean2, running_var = running_var2)) + st = merge(st, (running_mean=running_mean2, running_var=running_var2)) end return res, st end @@ -248,13 +248,13 @@ end function GroupNorm(chs::Int, groups::Int, - activation = identity; - init_bias = zeros32, - init_scale = ones32, - affine::Bool = true, - track_stats::Bool = true, - epsilon = 1.0f-5, - momentum = 0.1f0) + activation=identity; + init_bias=zeros32, + init_scale=ones32, + affine::Bool=true, + track_stats::Bool=true, + epsilon=1.0f-5, + momentum=0.1f0) @assert chs % groups==0 "The number of groups ($(groups)) must divide the number of channels ($chs)" activation = NNlib.fast_act(activation) return GroupNorm{affine, track_stats, typeof(activation), typeof(init_bias), @@ -264,16 +264,16 @@ function GroupNorm(chs::Int, end function initialparameters(rng::AbstractRNG, l::GroupNorm{affine}) where {affine} - return affine ? (scale = l.init_scale(rng, l.chs), bias = l.init_bias(rng, l.chs)) : + return affine ? 
(scale=l.init_scale(rng, l.chs), bias=l.init_bias(rng, l.chs)) : NamedTuple() end function initialstates(rng::AbstractRNG, l::GroupNorm{affine, track_stats}) where {affine, track_stats} return if track_stats - (running_mean = zeros32(rng, l.groups), running_var = ones32(rng, l.groups), - training = Val(true)) + (running_mean=zeros32(rng, l.groups), running_var=ones32(rng, l.groups), + training=Val(true)) else - (running_mean = nothing, running_var = nothing, training = Val(true)) + (running_mean=nothing, running_var=nothing, training=Val(true)) end end @@ -300,7 +300,7 @@ function (GN::GroupNorm)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N GN.momentum, GN.epsilon) - st = merge(st, (running_mean = xmean, running_var = xvar)) + st = merge(st, (running_mean=xmean, running_var=xvar)) return reshape(x_normalized, sz), st end @@ -352,7 +352,7 @@ struct WeightNorm{which_params, L <: AbstractExplicitLayer, D} <: AbstractExplic end function WeightNorm(layer::AbstractExplicitLayer, which_params::NTuple{N, Symbol}, - dims::Union{Tuple, Nothing} = nothing) where {N} + dims::Union{Tuple, Nothing}=nothing) where {N} return WeightNorm{Val{which_params}, typeof(layer), typeof(dims)}(layer, dims) end @@ -374,7 +374,7 @@ function initialparameters(rng::AbstractRNG, end end ps_unnormalized = length(ps_unnormalized) == 0 ? NamedTuple() : (; ps_unnormalized...) - return (normalized = (; ps_normalized...), unnormalized = ps_unnormalized) + return (normalized=(; ps_normalized...), unnormalized=ps_unnormalized) end initialstates(rng::AbstractRNG, wn::WeightNorm) = initialstates(rng, wn.layer) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 4c492ec6f..b728577d7 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -45,21 +45,21 @@ struct RNNCell{bias, A, B, W, S} <: AbstractExplicitLayer end function RNNCell((in_dims, out_dims)::Pair{<:Int, <:Int}, - activation = tanh; - bias::Bool = true, - init_bias = zeros32, - init_weight = glorot_uniform, - init_state = ones32) + activation=tanh; + bias::Bool=true, + init_bias=zeros32, + init_weight=glorot_uniform, + init_state=ones32) return RNNCell{bias, typeof(activation), typeof(init_bias), typeof(init_weight), typeof(init_state)}(activation, in_dims, out_dims, init_bias, init_weight, init_state) end function initialparameters(rng::AbstractRNG, rnn::RNNCell{bias}) where {bias} - ps = (weight_ih = rnn.init_weight(rng, rnn.out_dims, rnn.in_dims), - weight_hh = rnn.init_weight(rng, rnn.out_dims, rnn.out_dims)) + ps = (weight_ih=rnn.init_weight(rng, rnn.out_dims, rnn.in_dims), + weight_hh=rnn.init_weight(rng, rnn.out_dims, rnn.out_dims)) if bias - ps = merge(ps, (bias = rnn.init_bias(rng, rnn.out_dims),)) + ps = merge(ps, (bias=rnn.init_bias(rng, rnn.out_dims),)) end return ps end @@ -67,7 +67,7 @@ end function initialstates(rng::AbstractRNG, ::RNNCell) # FIXME: Take PRNGs seriously randn(rng, 1) - return (rng = replicate(rng),) + return (rng=replicate(rng),) end function (rnn::RNNCell)(x::AbstractMatrix, ps::Union{ComponentArray, NamedTuple}, @@ -171,15 +171,15 @@ struct LSTMCell{B, W, S} <: AbstractExplicitLayer end function LSTMCell((in_dims, out_dims)::Pair{<:Int, <:Int}; - init_weight::Tuple{Function, Function, Function, Function} = (glorot_uniform, - glorot_uniform, - glorot_uniform, - glorot_uniform), - init_bias::Tuple{Function, Function, Function, Function} = (zeros32, - zeros32, - ones32, - zeros32), - init_state::Function = zeros32) + init_weight::Tuple{Function, Function, Function, Function}=(glorot_uniform, + 
glorot_uniform, + glorot_uniform, + glorot_uniform), + init_bias::Tuple{Function, Function, Function, Function}=(zeros32, + zeros32, + ones32, + zeros32), + init_state::Function=zeros32) return LSTMCell(in_dims, out_dims, init_bias, init_weight, init_state) end @@ -189,13 +189,13 @@ function initialparameters(rng::AbstractRNG, lstm::LSTMCell) weight_h = vcat([init_weight(rng, lstm.out_dims, lstm.out_dims) for init_weight in lstm.init_weight]...) bias = vcat([init_bias(rng, lstm.out_dims, 1) for init_bias in lstm.init_bias]...) - return (weight_i = weight_i, weight_h = weight_h, bias = bias) + return (weight_i=weight_i, weight_h=weight_h, bias=bias) end function initialstates(rng::AbstractRNG, ::LSTMCell) # FIXME: Take PRNGs seriously randn(rng, 1) - return (rng = replicate(rng),) + return (rng=replicate(rng),) end function (lstm::LSTMCell)(x::AbstractMatrix, ps::Union{ComponentArray, NamedTuple}, @@ -272,12 +272,12 @@ struct GRUCell{W, B, S} <: AbstractExplicitLayer end function GRUCell((in_dims, out_dims)::Pair{<:Int, <:Int}; - init_weight::Tuple{Function, Function, Function} = (glorot_uniform, - glorot_uniform, - glorot_uniform), - init_bias::Tuple{Function, Function, Function} = (zeros32, zeros32, - zeros32), - init_state::Function = zeros32) + init_weight::Tuple{Function, Function, Function}=(glorot_uniform, + glorot_uniform, + glorot_uniform), + init_bias::Tuple{Function, Function, Function}=(zeros32, zeros32, + zeros32), + init_state::Function=zeros32) return GRUCell(in_dims, out_dims, init_weight, init_bias, init_state) end @@ -288,13 +288,13 @@ function initialparameters(rng::AbstractRNG, gru::GRUCell) for init_weight in gru.init_weight]...) bias_i = gru.init_bias[1](rng, gru.out_dims, 1) bias_h = vcat([init_bias(rng, gru.out_dims, 1) for init_bias in gru.init_bias]...) - return (weight_i = weight_i, weight_h = weight_h, bias_i = bias_i, bias_h = bias_h) + return (weight_i=weight_i, weight_h=weight_h, bias_i=bias_i, bias_h=bias_h) end function initialstates(rng::AbstractRNG, ::GRUCell) # FIXME: Take PRNGs seriously randn(rng, 1) - return (rng = replicate(rng),) + return (rng=replicate(rng),) end function (gru::GRUCell)(x::AbstractMatrix, ps::Union{ComponentArray, NamedTuple}, diff --git a/src/nnlib.jl b/src/nnlib.jl index f6e627cd5..31c5d2f6b 100644 --- a/src/nnlib.jl +++ b/src/nnlib.jl @@ -11,8 +11,8 @@ sx = size(x) m = T(prod((sx[i] for i in reduce_dims))) if reduce_dims[end] != N - batchmean = mean(batchmean; dims = N) - batchvar = mean(batchvar; dims = N) + batchmean = mean(batchmean; dims=N) + batchvar = mean(batchvar; dims=N) end running_mean = @. (1 - momentum) * running_mean + momentum * batchmean running_var = @. (1 - momentum) * running_var + momentum * batchvar * (m / (m - one(m))) @@ -35,8 +35,8 @@ Performs BatchNorm/GroupNorm/InstanceNorm based on input configuration activation, reduce_dims, t::Val, - momentum::T = T(0.1), - epsilon::T = T(1e-5); + momentum::T=T(0.1), + epsilon::T=T(1e-5); kwargs...) where {T, N} x_norm, running_mean_, running_var_ = normalization_forward(x, reshape_into_proper_shape(running_mean, @@ -137,7 +137,7 @@ end ## TODO: Cache `1 / q` since we never need `q` @inline _dropout_kernel(y::T, p, q) where {T} = y > p ? 
T(1 / q) : T(0) -@inline function generate_dropout_mask(rng::AbstractRNG, x, p; dims = :) +@inline function generate_dropout_mask(rng::AbstractRNG, x, p; dims=:) realfptype = float(real(eltype(x))) y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims))) y .= _dropout_kernel.(y, p, 1 - p) @@ -188,7 +188,7 @@ end stride = insize .÷ outsize k = insize .- (outsize .- 1) .* stride pad = 0 - return PoolDims(x, k; padding = pad, stride = stride) + return PoolDims(x, k; padding=pad, stride=stride) end # CUDNN Constants @@ -216,7 +216,7 @@ Apply the function `f` on `x` elementwise, i.e. `f.(x)`. Dispatches to CUDNN if """ @inline applyactivation(f::Function, x::AbstractArray) = f.(x) @inline function applyactivation(f::cudnnValidActivationTypes, x::CuArray{<:CUDNNFloat}) - return CUDNN.cudnnActivationForward(x; mode = getCUDNNActivationMode(f)) + return CUDNN.cudnnActivationForward(x; mode=getCUDNNActivationMode(f)) end @inline applyactivation(::typeof(identity), x::AbstractArray) = x @@ -225,7 +225,7 @@ end sx = size(x) sΔ = size(Δ) sx == sΔ && return Δ - return sum(Δ; dims = findall(sx .!= sΔ)) + return sum(Δ; dims=findall(sx .!= sΔ)) end @inline isvalidtensorop(x1, x2) = false @@ -243,7 +243,7 @@ Computes `x .+ y`. Dispatches to CUDNN if possible @inline elementwise_add(x, y) = x .+ y @inline function elementwise_add(x::CuArray, y::CuArray) !isvalidtensorop(x, y) && return x .+ y - return cudnnOpTensorWithDefaults(x, y; op = CUDNN.CUDNN_OP_TENSOR_ADD) + return cudnnOpTensorWithDefaults(x, y; op=CUDNN.CUDNN_OP_TENSOR_ADD) end @inline function elementwise_add_pullback(x, y, Δ) @@ -258,7 +258,7 @@ Computes `x .* y`. Dispatches to CUDNN if possible @inline elementwise_mul(x, y) = x .* y @inline function elementwise_mul(x::CuArray, y::CuArray) !isvalidtensorop(x, y) && return x .* y - return cudnnOpTensorWithDefaults(x, y; op = CUDNN.CUDNN_OP_TENSOR_MUL) + return cudnnOpTensorWithDefaults(x, y; op=CUDNN.CUDNN_OP_TENSOR_MUL) end @inline function elementwise_mul_pullback(x, y, Δ) @@ -269,20 +269,20 @@ end # CUDNN Helpers function cudnnOpTensorWithDefaults(x1, x2; - y = similar(x1), - op::CUDNN.cudnnOpTensorOp_t = CUDNN.CUDNN_OP_TENSOR_ADD, - compType::DataType = (eltype(x1) <: Float64 ? Float64 : - Float32), - nanOpt::CUDNN.cudnnNanPropagation_t = CUDNN.CUDNN_NOT_PROPAGATE_NAN, - opTensorDesc::CUDNN.cudnnOpTensorDescriptor = CUDNN.cudnnOpTensorDescriptor(op, - CUDNN.cudnnDataType(compType), - nanOpt), - alpha1::Real = 1, - alpha2::Real = 1, - beta::Real = 0, - x1Desc::CUDNN.cudnnTensorDescriptor = CUDNN.cudnnTensorDescriptor(x1), - x2Desc::CUDNN.cudnnTensorDescriptor = CUDNN.cudnnTensorDescriptor(x2), - yDesc::CUDNN.cudnnTensorDescriptor = CUDNN.cudnnTensorDescriptor(y)) + y=similar(x1), + op::CUDNN.cudnnOpTensorOp_t=CUDNN.CUDNN_OP_TENSOR_ADD, + compType::DataType=(eltype(x1) <: Float64 ? 
Float64 : + Float32), + nanOpt::CUDNN.cudnnNanPropagation_t=CUDNN.CUDNN_NOT_PROPAGATE_NAN, + opTensorDesc::CUDNN.cudnnOpTensorDescriptor=CUDNN.cudnnOpTensorDescriptor(op, + CUDNN.cudnnDataType(compType), + nanOpt), + alpha1::Real=1, + alpha2::Real=1, + beta::Real=0, + x1Desc::CUDNN.cudnnTensorDescriptor=CUDNN.cudnnTensorDescriptor(x1), + x2Desc::CUDNN.cudnnTensorDescriptor=CUDNN.cudnnTensorDescriptor(x2), + yDesc::CUDNN.cudnnTensorDescriptor=CUDNN.cudnnTensorDescriptor(y)) T = eltype(x1) alpha1, alpha2, beta = CUDNN.scalingParameter.((T,), (alpha1, alpha2, beta)) return CUDNN.cudnnOpTensorAD(x1, x2; opTensorDesc, alpha1, x1Desc, alpha2, x2Desc, beta, diff --git a/src/transform.jl b/src/transform.jl index 04bd8e34e..2e56712c6 100644 --- a/src/transform.jl +++ b/src/transform.jl @@ -24,9 +24,9 @@ transform(::T) where {T} = error("Transformation for type $T not implemented") transform(model::Flux.Chain) = Chain(transform.(model.layers)...) function transform(model::Flux.BatchNorm) - return BatchNorm(model.chs, model.λ; affine = model.affine, - track_stats = model.track_stats, epsilon = model.ϵ, - momentum = model.momentum) + return BatchNorm(model.chs, model.λ; affine=model.affine, + track_stats=model.track_stats, epsilon=model.ϵ, + momentum=model.momentum) end function transform(model::Flux.Conv) @@ -34,11 +34,11 @@ function transform(model::Flux.Conv) size(model.weight, ndims(model.weight) - 1) * model.groups => size(model.weight, ndims(model.weight)), model.σ; - stride = model.stride, - pad = model.pad, - bias = model.bias isa Bool ? model.bias : !(model.bias isa Flux.Zeros), - dilation = model.dilation, - groups = model.groups) + stride=model.stride, + pad=model.pad, + bias=model.bias isa Bool ? model.bias : !(model.bias isa Flux.Zeros), + dilation=model.dilation, + groups=model.groups) end function transform(model::Flux.SkipConnection) @@ -78,7 +78,7 @@ function transform(model::Flux.Parallel) end function transform(d::Flux.Dropout) - return Dropout(Float32(d.p); dims = d.dims) + return Dropout(Float32(d.p); dims=d.dims) end transform(::typeof(identity)) = NoOpLayer() diff --git a/src/utils.jl b/src/utils.jl index 2b138b3de..b6d0cac22 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -36,7 +36,7 @@ Return an `Array{Float32}` of the given `size` containing random numbers drawn f [1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010. """ -function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real = 1) +function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1) scale = Float32(gain) * sqrt(24.0f0 / sum(nfan(dims...))) return (rand(rng, Float32, dims...) .- 0.5f0) .* scale end @@ -50,7 +50,7 @@ Return an `Array{Float32}` of the given `size` containing random numbers drawn f [1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010. """ -function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real = 1) +function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1) std = Float32(gain) * sqrt(2.0f0 / sum(nfan(dims...))) return randn(rng, Float32, dims...) 
.* std end @@ -65,15 +65,15 @@ replicate(rng::CUDA.RNG) = deepcopy(rng) @inline istraining(st::NamedTuple) = istraining(st.training) # Linear Algebra -@inline _norm(x; dims = Colon()) = sqrt.(sum(abs2, x; dims = dims)) -@inline function _norm_except(x::AbstractArray{T, N}, except_dim = N) where {T, N} - _norm(x; dims = filter(i -> i != except_dim, 1:N)) +@inline _norm(x; dims=Colon()) = sqrt.(sum(abs2, x; dims=dims)) +@inline function _norm_except(x::AbstractArray{T, N}, except_dim=N) where {T, N} + _norm(x; dims=filter(i -> i != except_dim, 1:N)) end # Convolution function convfilter(rng::AbstractRNG, filter::NTuple{N, Integer}, ch::Pair{<:Integer, <:Integer}; - init = glorot_uniform, groups = 1) where {N} + init=glorot_uniform, groups=1) where {N} cin, cout = ch @assert cin % groups==0 "Input channel dimension must be divisible by groups." @assert cout % groups==0 "Output channel dimension must be divisible by groups." diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 17e0e5e53..778591686 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -15,8 +15,8 @@ Random.seed!(rng, 0) @test size(layer(x, ps, st)[1]) == (2, 3, 3) @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, - rtol = 1.0f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, + rtol=1.0f-3) end @testset "Flatten Layer" begin @@ -28,23 +28,23 @@ Random.seed!(rng, 0) @test size(layer(x, ps, st)[1]) == (18, 2) @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, - rtol = 1.0f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, + rtol=1.0f-3) end @testset "NoOpLayer" begin layer = NoOpLayer() println(layer) ps, st = Lux.setup(rng, layer) - x = (x = 2, b = 5) # Something totally arbitrary + x = (x=2, b=5) # Something totally arbitrary @test layer(x, ps, st)[1] == x @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) x = randn(rng, 6, 3) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, - rtol = 1.0f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, + rtol=1.0f-3) end @testset "SelectDim Layer" begin @@ -56,8 +56,8 @@ Random.seed!(rng, 0) @test size(layer(x, ps, st)[1]) == (6, 4, 2) @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) broken=true - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, - rtol = 1.0f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, + rtol=1.0f-3) end @testset "WrappedFunction" begin @@ -69,8 +69,8 @@ Random.seed!(rng, 0) @test layer(x, ps, st)[1] == x .* x @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, - rtol = 1.0f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, + rtol=1.0f-3) end @testset "ActivationFunction" begin @@ -82,8 +82,8 @@ Random.seed!(rng, 0) @test layer(x, ps, st)[1] == tanh.(x) @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, - rtol = 1.0f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, + rtol=1.0f-3) end end @@ -98,8 +98,8 @@ end @test layer(x, ps, 
st)[1] == x @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, - rtol = 1.0f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, + rtol=1.0f-3) end @testset "concat size" begin @@ -112,7 +112,7 @@ end @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol = 1.0f-3, rtol = 1.0f-3) + atol=1.0f-3, rtol=1.0f-3) end end @@ -126,12 +126,12 @@ end @test layer(x, ps, st)[1] == x @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, - rtol = 1.0f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, + rtol=1.0f-3) end @testset "concat size" begin - layer = Parallel((a, b) -> cat(a, b; dims = 2), Dense(10, 10), NoOpLayer()) + layer = Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), NoOpLayer()) println(layer) ps, st = Lux.setup(rng, layer) x = randn(rng, 10, 2) @@ -139,8 +139,8 @@ end @test size(layer(x, ps, st)[1]) == (10, 4) @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, - rtol = 1.0f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, + rtol=1.0f-3) layer = Parallel(hcat, Dense(10, 10), NoOpLayer()) println(layer) @@ -149,8 +149,8 @@ end @test size(layer(x, ps, st)[1]) == (10, 4) @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, - rtol = 1.0f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, + rtol=1.0f-3) end @testset "vararg input" begin @@ -163,7 +163,7 @@ end @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol = 1.0f-3, rtol = 1.0f-3) + atol=1.0f-3, rtol=1.0f-3) end @testset "connection is called once" begin @@ -189,7 +189,7 @@ end struct L1 <: Lux.AbstractExplicitLayer end (::L1)(x, ps, st) = (ps.x * x, st) - Lux.initialparameters(rng::AbstractRNG, ::L1) = (x = randn(rng, Float32, 3, 3),) + Lux.initialparameters(rng::AbstractRNG, ::L1) = (x=randn(rng, Float32, 3, 3),) Base.:*(a::AbstractArray, b::Input) = a * b.x par = Parallel(+, L1(), L1()) @@ -226,7 +226,7 @@ end @test size(ps.bias) == (100, 1) @test layer.activation == identity - layer = Dense(10, 100, relu; bias = false) + layer = Dense(10, 100, relu; bias=false) ps, st = Lux.setup(rng, layer) @test !haskey(ps, :bias) @@ -243,27 +243,27 @@ end @testset "zeros" begin @test begin - layer = Dense(10, 1, identity; init_weight = ones) + layer = Dense(10, 1, identity; init_weight=ones) first(Lux.apply(layer, ones(10, 1), Lux.setup(rng, layer)...)) end == 10 * ones(1, 1) @test begin - layer = Dense(10, 1, identity; init_weight = ones) + layer = Dense(10, 1, identity; init_weight=ones) first(Lux.apply(layer, ones(10, 2), Lux.setup(rng, layer)...)) end == 10 * ones(1, 2) @test begin - layer = Dense(10, 2, identity; init_weight = ones) + layer = Dense(10, 2, identity; init_weight=ones) first(Lux.apply(layer, ones(10, 1), Lux.setup(rng, layer)...)) end == 10 * ones(2, 1) @test begin - layer = Dense(10, 2, identity; init_weight = ones) + layer = Dense(10, 2, identity; init_weight=ones) 
first(Lux.apply(layer, [ones(10, 1) 2 * ones(10, 1)], Lux.setup(rng, layer)...)) end == [10 20; 10 20] @test begin - layer = Dense(10, 2, identity; init_weight = ones, bias = false) + layer = Dense(10, 2, identity; init_weight=ones, bias=false) first(Lux.apply(layer, [ones(10, 1) 2 * ones(10, 1)], Lux.setup(rng, layer)...)) end == [10 20; 10 20] end diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 0fe3378fe..bb4e9142c 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -70,7 +70,7 @@ include("../utils.jl") x = ones(Float32, (k .+ 3)..., 1, 1) - layer = ltype(k; pad = Lux.SamePad()) + layer = ltype(k; pad=Lux.SamePad()) ps, st = Lux.setup(rng, layer) @test size(layer(x, ps, st)[1])[1:(end - 2)] == cld.(size(x)[1:(end - 2)], k) @@ -82,7 +82,7 @@ end @testset "CNN" begin @testset "Grouped Conv" begin x = rand(rng, Float32, 4, 6, 1) - layer = Conv((3,), 6 => 2; groups = 2) + layer = Conv((3,), 6 => 2; groups=2) ps, st = Lux.setup(rng, layer) @test size(ps.weight) == (3, 3, 2) @@ -90,10 +90,10 @@ end @test_call layer(x, ps, st) broken=true @test_opt target_modules=(Lux,) layer(x, ps, st) test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol = 1.0f-3, rtol = 1.0f-3) + atol=1.0f-3, rtol=1.0f-3) x = rand(rng, Float32, 4, 4, 6, 1) - layer = Conv((3, 3), 6 => 2; groups = 2) + layer = Conv((3, 3), 6 => 2; groups=2) ps, st = Lux.setup(rng, layer) @test size(ps.weight) == (3, 3, 3, 2) @@ -101,10 +101,10 @@ end @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol = 1.0f-3, rtol = 1.0f-3) + atol=1.0f-3, rtol=1.0f-3) x = rand(rng, Float32, 4, 4, 4, 6, 1) - layer = Conv((3, 3, 3), 6 => 2; groups = 2) + layer = Conv((3, 3, 3), 6 => 2; groups=2) ps, st = Lux.setup(rng, layer) @test size(ps.weight) == (3, 3, 3, 3, 2) @@ -112,17 +112,17 @@ end @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol = 1.0f-3, rtol = 1.0f-3) + atol=1.0f-3, rtol=1.0f-3) # Test that we cannot ask for non-integer multiplication factors - layer = Conv((2, 2), 3 => 10, groups = 2) + layer = Conv((2, 2), 3 => 10, groups=2) @test_throws AssertionError Lux.setup(rng, layer) - layer = Conv((2, 2), 2 => 9, groups = 2) + layer = Conv((2, 2), 2 => 9, groups=2) @test_throws AssertionError Lux.setup(rng, layer) end @testset "Asymmetric Padding" begin - layer = Conv((3, 3), 1 => 1, relu; pad = (0, 1, 1, 2)) + layer = Conv((3, 3), 1 => 1, relu; pad=(0, 1, 1, 2)) x = ones(Float32, 28, 28, 1, 1) ps, st = Lux.setup(rng, layer) @@ -144,7 +144,7 @@ end @testset "Variable BitWidth Parameters" begin # https://github.com/FluxML/Flux.jl/issues/1421 - layer = Conv((5, 5), 10 => 20, identity; init_weight = Base.randn) + layer = Conv((5, 5), 10 => 20, identity; init_weight=Base.randn) ps, st = Lux.setup(rng, layer) @test ps.bias isa Array{Float64, 4} end @@ -152,68 +152,68 @@ end @testset "Depthwise Conv" begin x = randn(rng, Float32, 4, 4, 3, 2) - layer = Conv((2, 2), 3 => 15; groups = 3) + layer = Conv((2, 2), 3 => 15; groups=3) ps, st = Lux.setup(rng, layer) @test size(layer(x, ps, st)[1], 3) == 15 @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol = 1.0f-3, rtol = 1.0f-3) + atol=1.0f-3, rtol=1.0f-3) - layer = Conv((2, 2), 3 => 9; groups = 3) + layer = Conv((2, 2), 3 => 9; groups=3) ps, st 
= Lux.setup(rng, layer) @test size(layer(x, ps, st)[1], 3) == 9 @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol = 1.0f-3, rtol = 1.0f-3) + atol=1.0f-3, rtol=1.0f-3) - layer = Conv((2, 2), 3 => 9; groups = 3, bias = false) + layer = Conv((2, 2), 3 => 9; groups=3, bias=false) ps, st = Lux.setup(rng, layer) @test size(layer(x, ps, st)[1], 3) == 9 @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol = 1.0f-3, rtol = 1.0f-3) + atol=1.0f-3, rtol=1.0f-3) # Test that we cannot ask for non-integer multiplication factors - layer = Conv((2, 2), 3 => 10; groups = 3) + layer = Conv((2, 2), 3 => 10; groups=3) @test_throws AssertionError Lux.setup(rng, layer) end @testset "Conv SamePad kernelsize $k" for k in ((1,), (2,), (3,), (2, 3), (1, 2, 3)) x = ones(Float32, (k .+ 3)..., 1, 1) - layer = Conv(k, 1 => 1; pad = Lux.SamePad()) + layer = Conv(k, 1 => 1; pad=Lux.SamePad()) ps, st = Lux.setup(rng, layer) @test size(layer(x, ps, st)[1]) == size(x) @test_call layer(x, ps, st) broken=length(k) == 1 @test_opt target_modules=(Lux,) layer(x, ps, st) test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol = 1.0f-3, rtol = 1.0f-3) + atol=1.0f-3, rtol=1.0f-3) - layer = Conv(k, 1 => 1; pad = Lux.SamePad(), dilation = k .÷ 2) + layer = Conv(k, 1 => 1; pad=Lux.SamePad(), dilation=k .÷ 2) ps, st = Lux.setup(rng, layer) @test size(layer(x, ps, st)[1]) == size(x) @test_call layer(x, ps, st) broken=length(k) == 1 @test_opt target_modules=(Lux,) layer(x, ps, st) test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol = 1.0f-3, rtol = 1.0f-3) + atol=1.0f-3, rtol=1.0f-3) stride = 3 - layer = Conv(k, 1 => 1; pad = Lux.SamePad(), stride = stride) + layer = Conv(k, 1 => 1; pad=Lux.SamePad(), stride=stride) ps, st = Lux.setup(rng, layer) @test size(layer(x, ps, st)[1])[1:(end - 2)] == cld.(size(x)[1:(end - 2)], stride) @test_call layer(x, ps, st) broken=length(k) == 1 @test_opt target_modules=(Lux,) layer(x, ps, st) test_gradient_correctness_fdm((x, ps) -> sum(layer(x, ps, st)[1]), x, ps; - atol = 1.0f-3, rtol = 1.0f-3) + atol=1.0f-3, rtol=1.0f-3) end @testset "Conv with non quadratic window #700" begin diff --git a/test/layers/dropout.jl b/test/layers/dropout.jl index e030fd85d..70de0f52a 100644 --- a/test/layers/dropout.jl +++ b/test/layers/dropout.jl @@ -22,8 +22,8 @@ Random.seed!(rng, 0) @test_call layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, - rtol = 1.0f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, + rtol=1.0f-3) st = Lux.testmode(st) @@ -50,10 +50,10 @@ end @test_call layer(x, ps, st_) @test_opt target_modules=(Lux,) layer(x, ps, st) @test_opt target_modules=(Lux,) layer(x, ps, st_) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol = 1.0f-3, - rtol = 1.0f-3) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st_)[1]), x; atol = 1.0f-3, - rtol = 1.0f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st)[1]), x; atol=1.0f-3, + rtol=1.0f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st_)[1]), x; atol=1.0f-3, + rtol=1.0f-3) st__ = Lux.update_state(st_, :update_mask, Val(true)) x___, st___ = layer(x, ps, st__) @@ -63,6 +63,6 @@ end @test_call layer(x, ps, st__) @test_opt 
target_modules=(Lux,) layer(x, ps, st__) - test_gradient_correctness_fdm(x -> sum(layer(x, ps, st__)[1]), x; atol = 1.0f-3, - rtol = 1.0f-3) + test_gradient_correctness_fdm(x -> sum(layer(x, ps, st__)[1]), x; atol=1.0f-3, + rtol=1.0f-3) end diff --git a/test/layers/normalize.jl b/test/layers/normalize.jl index 9d1eb4b5e..368c00327 100644 --- a/test/layers/normalize.jl +++ b/test/layers/normalize.jl @@ -16,7 +16,7 @@ Random.seed!(rng, 0) @test ps.scale == [1, 1] # init_scale(2) y, st_ = pullback(m, x, ps, st)[1] - @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol = 1.0e-5) + @test isapprox(y, [-1.22474 0 1.22474; -1.22474 0 1.22474], atol=1.0e-5) # julia> x # 2×3 Array{Float64,2}: # 1.0 3.0 5.0 @@ -36,11 +36,11 @@ Random.seed!(rng, 0) # 1.3 # 1.3 @test st_.running_var ≈ - 0.1 .* var(x; dims = 2, corrected = false) .* (3 / 2) .+ 0.9 .* [1.0, 1.0] + 0.1 .* var(x; dims=2, corrected=false) .* (3 / 2) .+ 0.9 .* [1.0, 1.0] st_ = Lux.testmode(st_) x′ = m(x, ps, st_)[1] - @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5) + @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol=1.0e-5) @inferred m(x, ps, st) @@ -48,17 +48,17 @@ Random.seed!(rng, 0) @test_opt target_modules=(Lux,) m(x, ps, st) test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps; - atol = 1.0f-3, rtol = 1.0f-3) + atol=1.0f-3, rtol=1.0f-3) end - let m = BatchNorm(2; track_stats = false), x = [1.0f0 3.0f0 5.0f0; 2.0f0 4.0f0 6.0f0] + let m = BatchNorm(2; track_stats=false), x = [1.0f0 3.0f0 5.0f0; 2.0f0 4.0f0 6.0f0] ps, st = Lux.setup(rng, m) @inferred m(x, ps, st) @test_call m(x, ps, st) @test_opt target_modules=(Lux,) m(x, ps, st) test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps; - atol = 1.0f-3, rtol = 1.0f-3) + atol=1.0f-3, rtol=1.0f-3) end # with activation function @@ -69,13 +69,13 @@ Random.seed!(rng, 0) y, st_ = m(x, ps, st) @test isapprox(y, sigmoid.((x .- st_.running_mean) ./ - sqrt.(st_.running_var .+ m.epsilon)), atol = 1.0e-7) + sqrt.(st_.running_var .+ m.epsilon)), atol=1.0e-7) @inferred m(x, ps, st) @test_call m(x, ps, st) @test_opt target_modules=(Lux,) m(x, ps, st) test_gradient_correctness_fdm((x, ps) -> sum(first(m(x, ps, st))), x, ps; - atol = 1.0f-3, rtol = 1.0f-3) + atol=1.0f-3, rtol=1.0f-3) end let m = BatchNorm(2), x = reshape(Float32.(1:6), 3, 2, 1) @@ -97,9 +97,9 @@ end @testset "GroupNorm" begin # begin tests - squeeze(x) = dropdims(x; dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions + squeeze(x) = dropdims(x; dims=tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions - let m = GroupNorm(4, 2; track_stats = true), sizes = (3, 4, 2), + let m = GroupNorm(4, 2; track_stats=true), sizes = (3, 4, 2), x = reshape(collect(1:prod(sizes)), sizes) @test Lux.parameterlength(m) == 2 * 4 @@ -139,8 +139,8 @@ end n = prod(size(x)) ÷ m.groups ÷ size(x)[end] corr = n / (n - 1) z = reshape(x, 3, 2, 2, 2) - variance = var(z; dims = (1, 2), corrected = false) - @test st_.running_var ≈ 0.1 * corr * vec(mean(variance; dims = 4)) .+ 0.9 * 1 + variance = var(z; dims=(1, 2), corrected=false) + @test st_.running_var ≈ 0.1 * corr * vec(mean(variance; dims=4)) .+ 0.9 * 1 st__ = Lux.testmode(st_) y, st__ = m(x, ps, st__) @@ -151,7 +151,7 @@ end @inferred m(x, ps, st) @test_call m(x, ps, st) @test_opt target_modules=(Lux,) m(x, ps, st) - test_gradient_correctness_fdm(ps -> sum(first(m(x, ps, st))), ps; atol = 1.0f-3, - rtol = 1.0f-3) + test_gradient_correctness_fdm(ps -> sum(first(m(x, ps, st))), ps; atol=1.0f-3, + rtol=1.0f-3) end 
end diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 5265497a3..4f49a503a 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -6,9 +6,9 @@ rng = Random.default_rng() Random.seed!(rng, 0) @testset "RNNCell" begin for rnncell in (RNNCell(3 => 5, identity), - RNNCell(3 => 5, tanh), - RNNCell(3 => 5, tanh; bias = false), - RNNCell(3 => 5, identity; bias = false)) + RNNCell(3 => 5, tanh), + RNNCell(3 => 5, tanh; bias=false), + RNNCell(3 => 5, identity; bias=false)) println(rnncell) ps, st = Lux.setup(rng, rnncell) x = randn(rng, Float32, 3, 2) @@ -27,7 +27,7 @@ Random.seed!(rng, 0) return sum(abs2, h) end - test_gradient_correctness_fdm(loss_loop_rnncell, ps, atol = 1e-3, rtol = 1e-3) + test_gradient_correctness_fdm(loss_loop_rnncell, ps, atol=1e-3, rtol=1e-3) end end @testset "LSTMCell" begin @@ -50,7 +50,7 @@ end end return sum(abs2, h) end - test_gradient_correctness_fdm(loss_loop_lstmcell, ps, atol = 1e-3, rtol = 1e-3) + test_gradient_correctness_fdm(loss_loop_lstmcell, ps, atol=1e-3, rtol=1e-3) end @testset "GRUCell" begin @@ -73,7 +73,7 @@ end return sum(abs2, h) end - test_gradient_correctness_fdm(loss_loop_grucell, ps, atol = 1e-3, rtol = 1e-3) + test_gradient_correctness_fdm(loss_loop_grucell, ps, atol=1e-3, rtol=1e-3) end @testset "multigate" begin diff --git a/test/models/convnets.jl b/test/models/convnets.jl index 55d37dd46..42cae786e 100644 --- a/test/models/convnets.jl +++ b/test/models/convnets.jl @@ -13,7 +13,7 @@ GC.gc(true) @testset "VGG" begin @testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], bn in [true, false] - m = VGG(sz, batchnorm = bn) + m = VGG(sz, batchnorm=bn) m2 = Lux.transform(m.layers) @test size(run_model(m2, rand(Float32, 224, 224, 3, 1))) == (1000, 1) diff --git a/test/utils.jl b/test/utils.jl index c19ab5f04..5d5e4b386 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -33,7 +33,7 @@ function run_fwd_and_bwd(model, input, ps, st) return true end -function run_model(m::Lux.AbstractExplicitLayer, x, mode = :test) +function run_model(m::Lux.AbstractExplicitLayer, x, mode=:test) ps, st = Lux.setup(Random.default_rng(), m) if mode == :test st = Lux.testmode(st) From 5fcdcce529f1f655bb3b5c3561bbd55667a9e2f7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 24 May 2022 15:13:07 -0400 Subject: [PATCH 4/6] rebase complete --- docs/make.jl | 1 + src/layers/basic.jl | 70 +++++++++--------------------------- src/nnlib.jl | 88 +++++++++++++++++++-------------------------- 3 files changed, 54 insertions(+), 105 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index dcc689636..ed7a68c10 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -75,6 +75,7 @@ makedocs(; "Design Docs" => [ "Documentation" => "design/documentation.md", "Recurrent Neural Networks" => "design/recurrent.md", + "Add new functionality to Lux" => "design/core.md", ], ]) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index a06e4f525..b4fdd8eb4 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -237,13 +237,12 @@ end st_symbols = [gensym() for _ in 1:N] getinput(i) = T <: Tuple ? 
:(x[$i]) : :x calls = [] - append!( - calls, - [ - :(($(y_symbols[i]), $(st_symbols[i])) = layers[$i]($(getinput(i)), ps.$(names[i]), st.$(names[i]))) for - i in 1:N - ], - ) + append!(calls, + [:(($(y_symbols[i]), $(st_symbols[i])) = layers[$i]($(getinput(i)), + ps.$(names[i]), + st.$(names[i]))) + for + i in 1:N]) push!(calls, :(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))) if C == Nothing push!(calls, :($(y_symbols[N + 1]) = tuple($(Tuple(y_symbols[1:N])...)))) @@ -322,20 +321,12 @@ end y_symbols = [gensym() for _ in 1:N] st_symbols = [gensym() for _ in 1:N] calls = [] -<<<<<<< HEAD - append!( - calls, [:(($(y_symbols[i]), $(st_symbols[i])) = layers[$i](x, ps.$(names[i]), st.$(names[i]))) for i in 1:N] - ) - push!(calls, :(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))) - push!(calls, :(return tuple($(Tuple(y_symbols)...)), st)) -======= append!(calls, [:(($(y_symbols[i]), $(st_symbols[i])) = layers[$i](x, ps.$(names[i]), st.$(names[i]))) for i in 1:N]) - append!(calls, [:(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))]) - append!(calls, [:(return tuple($(Tuple(y_symbols)...)), st)]) ->>>>>>> 862526f (enforce SciMLStyle) + push!(calls, :(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))) + push!(calls, :(return tuple($(Tuple(y_symbols)...)), st)) return Expr(:block, calls...) end @@ -417,31 +408,14 @@ end st_symbols = [gensym() for _ in 1:N] getinput(i) = T <: Tuple ? :(x[$i]) : :x calls = [:($(y_symbols[N + 1]) = $(getinput(1)))] -<<<<<<< HEAD - append!( - calls, - [ - :( - ($(y_symbols[i]), $(st_symbols[i])) = layers[$i]($(y_symbols[N + 1]), ps.$(names[i]), st.$(names[i])); - $(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1))) - ) - for i in 1:N - ] - ) + append!(calls, + [:(($(y_symbols[i]), $(st_symbols[i])) = layers[$i]($(y_symbols[N + 1]), + ps.$(names[i]), + st.$(names[i])); + $(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1)))) + for i in 1:N]) push!(calls, :(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))) push!(calls, :(return $(y_symbols[N + 1]), st)) -======= - for i in 1:N - push!(calls, - :(($(y_symbols[i]), $(st_symbols[i])) = layers[$i]($(y_symbols[N + 1]), - ps.$(names[i]), - st.$(names[i])))) - push!(calls, - :($(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1))))) - end - append!(calls, [:(st = NamedTuple{$names}((($(Tuple(st_symbols)...),))))]) - append!(calls, [:(return $(y_symbols[N + 1]), st)]) ->>>>>>> 862526f (enforce SciMLStyle) return Expr(:block, calls...) end @@ -548,25 +522,13 @@ end x_symbols = [gensym() for _ in 1:N] st_symbols = [gensym() for _ in 1:N] calls = [:(($(x_symbols[1]), $(st_symbols[1])) = layers[1](x, ps.layer_1, st.layer_1))] -<<<<<<< HEAD - append!( - calls, - [ - :(($(x_symbols[i]), $(st_symbols[i])) = layers[$i]($(x_symbols[i - 1]), ps.$(fields[i]), st.$(fields[i]))) - for i in 2:N - ], - ) - push!(calls, :(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),))))) - push!(calls, :(return $(x_symbols[N]), st)) -======= append!(calls, [:(($(x_symbols[i]), $(st_symbols[i])) = layers[$i]($(x_symbols[i - 1]), ps.$(fields[i]), st.$(fields[i]))) for i in 2:N]) - append!(calls, [:(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),))))]) - append!(calls, [:(return $(x_symbols[N]), st)]) ->>>>>>> 862526f (enforce SciMLStyle) + push!(calls, :(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),))))) + push!(calls, :(return $(x_symbols[N]), st)) return Expr(:block, calls...) 
end diff --git a/src/nnlib.jl b/src/nnlib.jl index 31c5d2f6b..e61d3c1f1 100644 --- a/src/nnlib.jl +++ b/src/nnlib.jl @@ -56,49 +56,38 @@ Performs BatchNorm/GroupNorm/InstanceNorm based on input configuration return x_norm, safe_vec(running_mean_), safe_vec(running_var_) end -@generated function normalization_forward( - x::AbstractArray{T,N}, - running_mean::RM, - running_var::RV, - scale::S, - bias::B, - activation::A, - reduce_dims, - ::Val{training}, - momentum::T=T(0.1f0), - epsilon::T=T(1.0f-5); - kwargs..., -) where {RM,RV,S,B,T,N,A,training} +@generated function normalization_forward(x::AbstractArray{T, N}, + running_mean::RM, + running_var::RV, + scale::S, + bias::B, + activation::A, + reduce_dims, + ::Val{training}, + momentum::T=T(0.1f0), + epsilon::T=T(1.0f-5); + kwargs...) where {RM, RV, S, B, T, N, A, training} calls = [] if !training if RM == Nothing - expr = :( - batchmean = mean(x; dims=reduce_dims); - batchvar = var(x; mean=batchmean, dims=reduce_dims, corrected=false); - ) + expr = :(batchmean = mean(x; dims=reduce_dims); + batchvar = var(x; mean=batchmean, dims=reduce_dims, corrected=false)) else - expr = :( - batchmean = running_mean; - batchvar = running_var; - ) + expr = :(batchmean = running_mean; + batchvar = running_var) end push!(calls, expr) else - expr = :( - batchmean = mean(x; dims=reduce_dims); - batchvar = var(x; mean=batchmean, dims=reduce_dims, corrected=false); - ) + expr = :(batchmean = mean(x; dims=reduce_dims); + batchvar = var(x; mean=batchmean, dims=reduce_dims, corrected=false)) push!(calls, expr) if RM != Nothing - push!( - calls, - :( - (running_mean, running_var) = update_statistics( - x, running_mean, running_var, batchmean, batchvar, momentum, reduce_dims - ) - ), - ) + push!(calls, + :((running_mean, running_var) = update_statistics(x, running_mean, + running_var, batchmean, + batchvar, momentum, + reduce_dims))) end end @@ -106,7 +95,8 @@ end if A == typeof(identity) :(result = @. scale * (x - batchmean) / sqrt(batchvar + epsilon) + bias) else - :(result = @. activation(scale * (x - batchmean) / sqrt(batchvar + epsilon) + bias)) + :(result = @. activation(scale * (x - batchmean) / sqrt(batchvar + epsilon) + + bias)) end else if A == typeof(identity) @@ -150,32 +140,28 @@ end If `training` then dropout is applied on `x` with probability `prob` along `dims`. If `mask` is passed it is used if `update_mask` is false. If `update_mask` is true then the mask is generated and used. 
""" -@inline @generated function dropout(rng::AbstractRNG, x, prob, dims, ::Val{training}) where {training} +@inline @generated function dropout(rng::AbstractRNG, x, prob, dims, + ::Val{training}) where {training} if training - return :( - rng = replicate(rng); - mask = generate_dropout_mask(rng, x, prob; dims); - return (elementwise_mul(x, ignore_derivatives(mask)), mask, rng) - ) + return :(rng = replicate(rng); + mask = generate_dropout_mask(rng, x, prob; dims); + return (elementwise_mul(x, ignore_derivatives(mask)), mask, rng)) else return :(return (x, x, rng)) end end -@inline @generated function dropout( - rng::AbstractRNG, x, mask, prob, dims, t::Val{training}, ::Val{update_mask} -) where {training,update_mask} +@inline @generated function dropout(rng::AbstractRNG, x, mask, prob, dims, t::Val{training}, + ::Val{update_mask}) where {training, update_mask} if update_mask - return :( - (y, mask, rng) = dropout(rng, x, prob, dims, t); - return (y, mask, rng, Val(false)) - ) + return :((y, mask, rng) = dropout(rng, x, prob, dims, t); + return (y, mask, rng, Val(false))) else if training - return :( - size(x, ndims(x)) != size(mask, ndims(x)) && return (dropout(rng, x, prob, dims, t)..., Val(false)); - return (elementwise_mul(x, ignore_derivatives(mask)), mask, rng, Val(false)) - ) + return :(size(x, ndims(x)) != size(mask, ndims(x)) && + return (dropout(rng, x, prob, dims, t)..., Val(false)); + return (elementwise_mul(x, ignore_derivatives(mask)), mask, rng, + Val(false))) else return :(return (x, mask, rng, Val(false))) end From cb96e29380c1cae23f9d86c01a3b4c0b628ac1f8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 24 May 2022 15:15:46 -0400 Subject: [PATCH 5/6] Fix ugly --- src/layers/basic.jl | 2 +- src/layers/conv.jl | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index b4fdd8eb4..269bf1e6f 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -606,7 +606,7 @@ function initialparameters(rng::AbstractRNG, d::Dense{bias}) where {bias} end function parameterlength(d::Dense{bias}) where {bias} - bias ? d.out_dims * (d.in_dims + 1) : d.out_dims * d.in_dims + return bias ? d.out_dims * (d.in_dims + 1) : d.out_dims * d.in_dims end statelength(d::Dense) = 0 diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 254765931..908cbf0dd 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -79,10 +79,12 @@ end function initialparameters(rng::AbstractRNG, c::Conv{N, bias}) where {N, bias} weight = convfilter(rng, c.kernel_size, c.in_chs => c.out_chs; init=c.init_weight, groups=c.groups) - return bias ? 
- (weight=weight, - bias=zeros(eltype(weight), ntuple(_ -> 1, N)..., c.out_chs, 1)) : - (weight=weight,) + if bias + return (weight=weight, + bias=zeros(eltype(weight), ntuple(_ -> 1, N)..., c.out_chs, 1)) + else + return (weight=weight,) + end end function parameterlength(c::Conv{N, bias}) where {N, bias} @@ -92,15 +94,15 @@ end @inline function (c::Conv{N, false})(x::AbstractArray, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) where {N} - cdims = DenseConvDims(x, ps.weight; stride=c.stride, padding=c.pad, - dilation=c.dilation, groups=c.groups) + cdims = DenseConvDims(x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, + groups=c.groups) return applyactivation(c.activation, conv_wrapper(x, ps.weight, cdims)), st end @inline function (c::Conv{N, true})(x::AbstractArray, ps::Union{ComponentArray, NamedTuple}, st::NamedTuple) where {N} - cdims = DenseConvDims(x, ps.weight; stride=c.stride, padding=c.pad, - dilation=c.dilation, groups=c.groups) + cdims = DenseConvDims(x, ps.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, + groups=c.groups) return applyactivation(c.activation, elementwise_add(conv_wrapper(x, ps.weight, cdims), ps.bias)), st end From 1b813c0e0a02473ff31e2ef7905dfa9c60cdda9a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 24 May 2022 21:35:24 -0400 Subject: [PATCH 6/6] use juliaformatter 1.0 --- .github/workflows/FormatCheck.yml | 5 +---- .github/workflows/FormatPR.yml | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml index cc56827c4..7b7e4866c 100644 --- a/.github/workflows/FormatCheck.yml +++ b/.github/workflows/FormatCheck.yml @@ -26,11 +26,8 @@ jobs: # This will use the latest version by default but you can set the version like so: # # julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter", version="0.13.0"))' - # - # FIXME: Before merging change to default release - # julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(url="https://github.com/YingboMa/JuliaFormatter.jl.git", rev="myb/scimlstyle"))' + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' julia -e 'using JuliaFormatter; format(".", verbose=true)' - name: Format check run: | diff --git a/.github/workflows/FormatPR.yml b/.github/workflows/FormatPR.yml index f2113e8e5..5abd9a4df 100644 --- a/.github/workflows/FormatPR.yml +++ b/.github/workflows/FormatPR.yml @@ -9,7 +9,7 @@ jobs: - uses: actions/checkout@v2 - name: Install JuliaFormatter and format run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(url="https://github.com/YingboMa/JuliaFormatter.jl.git", rev="myb/scimlstyle"))' + julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' julia -e 'using JuliaFormatter; format(".")' # https://github.com/marketplace/actions/create-pull-request # https://github.com/peter-evans/create-pull-request#reference-example