diff --git a/ext/BoltzMetalheadExt.jl b/ext/BoltzMetalheadExt.jl index 05113c7..f98fa41 100644 --- a/ext/BoltzMetalheadExt.jl +++ b/ext/BoltzMetalheadExt.jl @@ -15,10 +15,12 @@ function Vision.ResNetMetalhead(depth::Int; pretrained::Bool=false) depth; pretrain=pretrained).layers) end -function Vision.ResNeXtMetalhead(depth::Int; pretrained::Bool=false) +function Vision.ResNeXtMetalhead( + depth::Int; cardinality=32, base_width=nothing, pretrained::Bool=false) @argcheck depth in (50, 101, 152) + base_width = base_width === nothing ? (depth == 101 ? 8 : 4) : base_width return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.ResNeXt( - depth; pretrain=pretrained).layers) + depth; pretrain=pretrained, cardinality, base_width).layers) end function Vision.GoogLeNetMetalhead(; pretrained::Bool=false) diff --git a/src/vision/extensions.jl b/src/vision/extensions.jl index f94a4ae..7a20d9b 100644 --- a/src/vision/extensions.jl +++ b/src/vision/extensions.jl @@ -15,7 +15,7 @@ Create a ResNet model [he2016deep](@citep). function ResNet end """ - ResNeXt(depth::Int; pretrained::Bool=false) + ResNeXt(depth::Int; cardinality=32, base_width=nothing, pretrained::Bool=false) Create a ResNeXt model [xie2017aggregated](@citep). @@ -27,6 +27,9 @@ Create a ResNeXt model [xie2017aggregated](@citep). - `pretrained::Bool=false`: If `true`, loads pretrained weights when `LuxCore.setup` is called. + - `cardinality`: The cardinality of the ResNeXt model. Defaults to 32. + - `base_width`: The base width of the ResNeXt model. Defaults to 8 for depth 101 and 4 + otherwise. """ function ResNeXt end @@ -132,12 +135,12 @@ for f in [:ResNet, :ResNeXt, :GoogLeNet, :DenseNet, f_metalhead = Symbol(f, :Metalhead) @eval begin function $(f_metalhead) end - function $(f)(args...; pretrained::Bool=false) + function $(f)(args...; pretrained::Bool=false, kwargs...) if !is_extension_loaded(Val(:Metalhead)) error("`Metalhead.jl` is not loaded. Please load `Metalhead.jl` to use \ this function.") end - model = $(f_metalhead)(args...; pretrained) + model = $(f_metalhead)(args...; pretrained, kwargs...) return MetalheadWrapperLayer(model, :metalhead, false) end end