Skip to content

Commit

Permalink
Update blas.jl
Browse files Browse the repository at this point in the history
1. simplify variable initialization
2. add fabs
  • Loading branch information
N5N3 committed Nov 6, 2021
1 parent 4181e91 commit b8f4ab7
Showing 1 changed file with 30 additions and 38 deletions.
68 changes: 30 additions & 38 deletions stdlib/LinearAlgebra/test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,16 @@ module TestBLAS

using Test, LinearAlgebra, Random
using LinearAlgebra: BlasReal, BlasComplex
fabs(x::Real) = abs(x)
fabs(x::Complex) = abs(real(x)) + abs(imag(x))

Random.seed!(100)
## BLAS tests - testing the interface code to BLAS routines
@testset for elty in [Float32, Float64, ComplexF32, ComplexF64]

@testset "syr2k!" begin
U = randn(5,2)
V = randn(5,2)
if elty == ComplexF32 || elty == ComplexF64
U = complex.(U, U)
V = complex.(V, V)
end
U = convert(Array{elty, 2}, U)
V = convert(Array{elty, 2}, V)
U = randn(elty, 5, 2)
V = randn(elty, 5, 2)
@test tril(LinearAlgebra.BLAS.syr2k('L','N',U,V)) tril(U*transpose(V) + V*transpose(U))
@test triu(LinearAlgebra.BLAS.syr2k('U','N',U,V)) triu(U*transpose(V) + V*transpose(U))
@test tril(LinearAlgebra.BLAS.syr2k('L','T',U,V)) tril(transpose(U)*V + transpose(V)*U)
Expand All @@ -26,12 +22,8 @@ Random.seed!(100)

if elty in (ComplexF32, ComplexF64)
@testset "her2k!" begin
U = randn(5,2)
V = randn(5,2)
U = complex.(U, U)
V = complex.(V, V)
U = convert(Array{elty, 2}, U)
V = convert(Array{elty, 2}, V)
U = randn(elty, 5, 2)
V = randn(elty, 5, 2)
@test tril(LinearAlgebra.BLAS.her2k('L','N',U,V)) tril(U*V' + V*U')
@test triu(LinearAlgebra.BLAS.her2k('U','N',U,V)) triu(U*V' + V*U')
@test tril(LinearAlgebra.BLAS.her2k('L','C',U,V)) tril(U'*V + V'*U)
Expand All @@ -48,34 +40,34 @@ Random.seed!(100)
U4 = triu(fill(elty(1), 4,4))
Z4 = zeros(elty, (4,4))

elm1 = convert(elty, -1)
el2 = convert(elty, 2)
v14 = convert(Vector{elty}, [1:4;])
v41 = convert(Vector{elty}, [4:-1:1;])
elm1 = elty(-1)
el2 = elty(2)
v14 = elty[1:4;]
v41 = elty[4:-1:1;]

let n = 10
@testset "dot products" begin
if elty <: Real
x1 = convert(Vector{elty}, randn(n))
x2 = convert(Vector{elty}, randn(n))
x1 = randn(elty, n)
x2 = randn(elty, n)
@test BLAS.dot(x1,x2) sum(x1.*x2)
@test_throws DimensionMismatch BLAS.dot(x1,rand(elty, n + 1))
else
z1 = convert(Vector{elty}, complex.(randn(n),randn(n)))
z2 = convert(Vector{elty}, complex.(randn(n),randn(n)))
z1 = randn(elty, n)
z2 = randn(elty, n)
@test BLAS.dotc(z1,z2) sum(conj(z1).*z2)
@test BLAS.dotu(z1,z2) sum(z1.*z2)
@test_throws DimensionMismatch BLAS.dotc(z1,rand(elty, n + 1))
@test_throws DimensionMismatch BLAS.dotu(z1,rand(elty, n + 1))
end
end
@testset "iamax" begin
x = convert(Vector{elty}, randn(elty, n))
@test BLAS.iamax(x) == findmax(x -> abs(real(x)) + abs(imag(x)), x)[2]
x = randn(elty, n)
@test BLAS.iamax(x) == findmax(fabs, x)[2]
end
@testset "rot!" begin
x = convert(Vector{elty}, randn(elty, n))
y = convert(Vector{elty}, randn(elty, n))
x = randn(elty, n)
y = randn(elty, n)
c = rand(real(elty))
for sty in unique!([real(elty), elty])
s = rand(sty)
Expand All @@ -87,8 +79,8 @@ Random.seed!(100)
end
end
@testset "axp(b)y" begin
x1 = convert(Vector{elty}, randn(elty, n))
x2 = convert(Vector{elty}, randn(elty, n))
x1 = randn(elty, n)
x2 = randn(elty, n)
α = rand(elty)
β = rand(elty)
@test BLAS.axpy!(α,copy(x1),copy(x2)) α*x1 + x2
Expand All @@ -111,18 +103,20 @@ Random.seed!(100)
a = rand(elty,n)
b = view(a,2:2:n,1)
@test BLAS.nrm2(b) norm(b)
@test BLAS.asum(b) sum(x -> abs(real(x)) + abs(imag(x)), b)
@test BLAS.iamax(b) == findmax(x -> abs(real(x)) + abs(imag(x)), b)[2]
@test BLAS.asum(b) sum(fabs, b)
@test BLAS.iamax(b) == findmax(fabs, b)[2]
# negative stride test
c = view(a,n:-2:2)
@test BLAS.nrm2(c) norm(c)
@test BLAS.asum(c) sum(x -> abs(real(x)) + abs(imag(x)), c)
@test BLAS.asum(c) sum(fabs, c)
@test BLAS.iamax(c) == 0
end
# scal
α = rand(elty)
a = rand(elty,n)
@test BLAS.scal(n,α,a,1) α * a
# negative stride test
@test BLAS.scal!(α, view(copy(a), n:-1:1)) α * reverse(a)

@testset "trsv" begin
A = triu(rand(elty,n,n))
Expand Down Expand Up @@ -155,15 +149,13 @@ Random.seed!(100)
end
end
@testset "copy" begin
x1 = convert(Vector{elty}, randn(n))
x2 = convert(Vector{elty}, randn(n))
BLAS.copyto!(x2, 1:n, x1, 1:n)
@test x2 == x1
x1 = randn(elty, n)
x2 = randn(elty, n)
@test x2 === BLAS.copyto!(x2, 1:n, x1, 1:n) == x1
@test_throws DimensionMismatch BLAS.copyto!(x2, 1:n, x1, 1:(n - 1))
@test_throws ArgumentError BLAS.copyto!(x1, 0:div(n, 2), x2, 1:(div(n, 2) + 1))
@test_throws ArgumentError BLAS.copyto!(x1, 1:(div(n, 2) + 1), x2, 0:div(n, 2))
BLAS.copyto!(x2, 1:n, x1, n:-1:1)
@test x2 == reverse(x1)
@test x2 === BLAS.copyto!(x2, 1:n, x1, n:-1:1) == reverse(x1)
end
# trmv
A = triu(rand(elty,n,n))
Expand Down Expand Up @@ -550,7 +542,7 @@ end

M = fill(elty(1.0), 3, 3)
@test BLAS.scal!(elty(2), view(M,:,2)) === view(M,:,2)
@test BLAS.scal!(elty(3), view(M,3,3:-1:1)) === view(M,3,3:-1:1)
@test BLAS.scal!(elty(3), view(M,3,:)) === view(M,3,:)
@test M == elty[1. 2. 1.; 1. 2. 1.; 3. 6. 3.]
# Level 2
A = WrappedArray(elty[1 2; 3 4])
Expand Down

0 comments on commit b8f4ab7

Please sign in to comment.