Skip to content

Commit

Permalink
Merge pull request #364 from boathit/master
Browse files Browse the repository at this point in the history
fix argmax and add test
  • Loading branch information
MikeInnes authored Aug 23, 2018
2 parents dfe7578 + dcde6d2 commit 953280d
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 10 deletions.
12 changes: 6 additions & 6 deletions docs/src/data/onehot.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
It's common to encode categorical variables (like `true`, `false` or `cat`, `dog`) in "one-of-k" or ["one-hot"](https://en.wikipedia.org/wiki/One-hot) form. Flux provides the `onehot` function to make this easy.

```
julia> using Flux: onehot
julia> using Flux: onehot, onecold
julia> onehot(:b, [:a, :b, :c])
3-element Flux.OneHotVector:
Expand All @@ -18,22 +18,22 @@ julia> onehot(:c, [:a, :b, :c])
true
```

The inverse is `argmax` (which can take a general probability distribution, as well as just booleans).
The inverse is `onecold` (which can take a general probability distribution, as well as just booleans).

```julia
julia> argmax(ans, [:a, :b, :c])
julia> onecold(ans, [:a, :b, :c])
:c

julia> argmax([true, false, false], [:a, :b, :c])
julia> onecold([true, false, false], [:a, :b, :c])
:a

julia> argmax([0.3, 0.2, 0.5], [:a, :b, :c])
julia> onecold([0.3, 0.2, 0.5], [:a, :b, :c])
:c
```

## Batches

`onehotbatch` creates a batch (matrix) of one-hot vectors, and `argmax` treats matrices as batches.
`onehotbatch` creates a batch (matrix) of one-hot vectors, and `onecold` treats matrices as batches.

```julia
julia> using Flux: onehotbatch
Expand Down
12 changes: 8 additions & 4 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,15 @@ end
onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])

argmax(y::AbstractVector, labels = 1:length(y)) =
labels[findfirst(y, maximum(y))]
onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]

argmax(y::AbstractMatrix, l...) =
squeeze(mapslices(y -> argmax(y, l...), y, 1), 1)
onecold(y::AbstractMatrix, labels...) =
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)

function argmax(xs...)
Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax)
return onecold(xs...)
end

# Ambiguity hack

Expand Down
13 changes: 13 additions & 0 deletions test/onehot.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using Flux:onecold
using Test

@testset "onecold" begin
a = [1, 2, 5, 3.]
A = [1 20 5; 2 7 6; 3 9 10; 2 1 14]
labels = ['A', 'B', 'C', 'D']

@test onecold(a) == 3
@test onecold(A) == [3, 1, 4]
@test onecold(a, labels) == 'C'
@test onecold(A, labels) == ['C', 'A', 'D']
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ insert!(LOAD_PATH, 2, "@v#.#")
@testset "Flux" begin

include("utils.jl")
include("onehot.jl")
include("tracker.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
Expand Down

0 comments on commit 953280d

Please sign in to comment.