Skip to content

Commit

Permalink
Move more logic to defn
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Feb 16, 2024
1 parent 034764d commit b1f35e1
Showing 1 changed file with 23 additions and 22 deletions.
45 changes: 23 additions & 22 deletions lib/axon/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2041,34 +2041,35 @@ defmodule Axon.Layers do
deftransformp resize_with_kernel(input, out_shape, spatial_axes, kernel_fun) do
for axis <- spatial_axes, reduce: input do
input ->
input_shape = Nx.shape(input)
input_size = elem(input_shape, axis)
output_size = elem(out_shape, axis)
resize_axis_with_kernel(input,
axis: axis,
output_size: elem(out_shape, axis),
kernel_fun: kernel_fun
)
end
end

inv_scale = input_size / output_size
kernel_scale = Nx.max(1, inv_scale)
defnp resize_axis_with_kernel(input, opts) do
axis = opts[:axis]
output_size = opts[:output_size]
kernel_fun = opts[:kernel_fun]

sample_f =
Nx.add(Nx.iota({1, output_size}), 0.5)
|> Nx.multiply(inv_scale)
|> Nx.subtract(0.5)
input_size = Nx.axis_size(input, axis)

x = Nx.abs(Nx.subtract(sample_f, Nx.iota({input_size, 1}))) |> Nx.divide(kernel_scale)
weights = kernel_fun.(x)
inv_scale = input_size / output_size
kernel_scale = max(1, inv_scale)

weights_sum = Nx.sum(weights, axes: [0], keep_axes: true)
sample_f = (Nx.iota({1, output_size}) + 0.5) * inv_scale - 0.5
x = Nx.abs(sample_f - Nx.iota({input_size, 1})) / kernel_scale
weights = kernel_fun.(x)

weights =
Nx.select(
Nx.greater(Nx.abs(weights), 1000 * @f32_eps),
safe_divide(weights, weights_sum),
0
)
weights_sum = Nx.sum(weights, axes: [0], keep_axes: true)

input = Nx.dot(input, [axis], weights, [0])
# The transformed axis is moved to the end, so we transpose back
reorder_axis(input, -1, axis)
end
weights = Nx.select(Nx.abs(weights) > 1000 * @f32_eps, safe_divide(weights, weights_sum), 0)

input = Nx.dot(input, [axis], weights, [0])
# The transformed axis is moved to the end, so we transpose back
reorder_axis(input, -1, axis)
end

defnp fill_linear_kernel(x) do
Expand Down

0 comments on commit b1f35e1

Please sign in to comment.