Summary: add a `disable_loop_vectorization` preference.
  * Project.toml: declares the Preferences dependency and its compat bound.
  * src/utils.jl: new module LuxLibPreferences reads the preference once via
    `load_preference` (default `false`) into DISABLE_LOOP_VECTORIZATION; when it
    is true, `can_loopvec_args(args...)` is defined to return `false`.
  * src/traits.jl: when it is true, `use_octavian()` is defined to return `False()`.
NOTE(review): leading indentation inside the hunks was reconstructed (4-space
Julia style, aligned continuation lines) -- confirm against the target files
before applying with strict whitespace matching.

diff --git a/Project.toml b/Project.toml
index 1364ae8b..7225334c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -19,6 +19,7 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -78,6 +79,7 @@ MLDataDevices = "1.2"
 Markdown = "1.10"
 NNlib = "0.9.24"
 Octavian = "0.3.28"
+Preferences = "1.4.3"
 Polyester = "0.7.15"
 Random = "1.10"
 Reexport = "1"
diff --git a/src/traits.jl b/src/traits.jl
index cc8c2937..02713457 100644
--- a/src/traits.jl
+++ b/src/traits.jl
@@ -8,6 +8,7 @@ using Static: True, False, static
 using StaticArraysCore: StaticArray
 
 using ..LuxLib: Numeric
+using ..LuxLibPreferences: DISABLE_LOOP_VECTORIZATION
 using ..Utils: NotaNumber, only_derivative, unrolled_any, unrolled_map
 
 function fast_scalar_indexing(::T) where {T <: AbstractArray}
@@ -130,9 +131,13 @@ end
 
 CRC.@non_differentiable explicit_blas_loaded()
 
-function use_octavian()
-    return is_extension_loaded(Val(:Octavian)) & is_x86_64() &
-           (INTEL_HARDWARE | AMD_RYZEN_HARDWARE)
+@static if DISABLE_LOOP_VECTORIZATION
+    use_octavian() = False()
+else
+    function use_octavian()
+        return is_extension_loaded(Val(:Octavian)) & is_x86_64() &
+               (INTEL_HARDWARE | AMD_RYZEN_HARDWARE)
+    end
 end
 
 CRC.@non_differentiable use_octavian()
diff --git a/src/utils.jl b/src/utils.jl
index eaa60f08..1e1e6e35 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,3 +1,14 @@
+module LuxLibPreferences
+
+using Preferences: load_preference
+
+using ..LuxLib: LuxLib
+
+const DISABLE_LOOP_VECTORIZATION = load_preference(
+    LuxLib, "disable_loop_vectorization", false)
+
+end
+
 module Utils
 
 using ChainRulesCore: ChainRulesCore
@@ -12,6 +23,7 @@ using Static: Static, StaticBool, False, True, static
 using StaticArraysCore: SVector, SMatrix
 
 using ..LuxLib: Optional, ∂∅
+using ..LuxLibPreferences: DISABLE_LOOP_VECTORIZATION
 
 const CRC = ChainRulesCore
 const KA = KernelAbstractions
@@ -325,8 +337,12 @@ end
 
 CRC.@non_differentiable static_training_mode_check(::Any...)
 
-@inline function can_loopvec_args(args...)
-    return can_loopvec_args_check(is_extension_loaded(Val(:LoopVectorization)), args...)
+@static if DISABLE_LOOP_VECTORIZATION
+    @inline can_loopvec_args(args...) = false
+else
+    @inline function can_loopvec_args(args...)
+        return can_loopvec_args_check(is_extension_loaded(Val(:LoopVectorization)), args...)
+    end
 end
 
 @inline can_loopvec_args_check(::False, args...) = false