Vectorize InstanceNormalization and BatchNormalization #469

Merged: 4 commits merged into main from vec-instance-norm on Dec 19, 2024

@robertknight robertknight commented Dec 19, 2024

InstanceNormalization has two steps:

  1. Compute channel mean and variance
  2. Shift and scale result to normalize the mean and variance, and then apply a per-channel scale and bias

BatchNormalization is similar but uses precomputed values for step 1. Previously only the mean calculation in step 1 was vectorized, not the variance calculation. This PR vectorizes the variance calculation, adds a vectorized kernel for step 2, and uses that kernel for both BatchNormalization and InstanceNormalization.

In tests with the wav2vec example on x64, this made InstanceNormalization ~3x faster (~26ms -> ~8ms per run).
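As a rough scalar reference for the two steps in the description, the per-channel computation might look like the sketch below. It is not the vectorized kernel from this PR: the function name `instance_norm_channel` is hypothetical, and the `epsilon` term and the use of population variance are assumptions taken from the ONNX InstanceNormalization definition.

```rust
/// Scalar reference for InstanceNormalization over the elements of a single
/// channel. Hypothetical helper for illustration, not the code in this PR.
fn instance_norm_channel(xs: &mut [f32], scale: f32, bias: f32, epsilon: f32) {
    if xs.is_empty() {
        return;
    }
    let n = xs.len() as f32;

    // Step 1: compute the channel mean and (population) variance.
    let mean = xs.iter().sum::<f32>() / n;
    let variance = xs.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / n;

    // Step 2: shift and scale to normalize the mean and variance, then apply
    // the per-channel scale and bias.
    let inv_std = 1.0 / (variance + epsilon).sqrt();
    for x in xs.iter_mut() {
        *x = (*x - mean) * inv_std * scale + bias;
    }
}
```

With `scale = 1`, `bias = 0` and a negligible `epsilon`, the output of this sketch has a mean close to 0 and a variance close to 1.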

This enables printing a debug representation of a vector implementing
`Simd` via `println!("{:?}", simd_vec.to_array())`.

Add two vectorized functions that will be useful as part of the
InstanceNormalization operation.

 - `vec_sum_square_sub` is like `vec_sum_square` but subtracts a constant from
   each element before squaring.

 - `vec_shift_scale_bias` subtracts a constant from each element and then shifts
   and scales the result.
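To make the intent of the two functions concrete, here are scalar sketches of what they appear to compute. The names with a `_scalar` suffix, the parameter names and the argument order are hypothetical, and the formula used for the second function, `(x - shift) * scale + bias`, is an assumption inferred from its name rather than taken from the actual implementation.

```rust
/// Sum of squares after subtracting a constant from each element:
/// returns the sum of `(x - offset)^2` over `xs`.
/// Hypothetical scalar stand-in for `vec_sum_square_sub`.
fn sum_square_sub_scalar(xs: &[f32], offset: f32) -> f32 {
    xs.iter().map(|&x| (x - offset) * (x - offset)).sum()
}

/// Subtract a constant from each element, then scale and add a bias.
/// Assumed formula: `x = (x - shift) * scale + bias`.
/// Hypothetical scalar stand-in for `vec_shift_scale_bias`.
fn shift_scale_bias_scalar(xs: &mut [f32], shift: f32, scale: f32, bias: f32) {
    for x in xs.iter_mut() {
        *x = (*x - shift) * scale + bias;
    }
}
```

Passing the channel mean as `offset` and dividing the result by the element count gives the variance, which is how a reduction of this shape fits into InstanceNormalization.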
…tion

InstanceNormalization has three steps:

 1. Compute channel mean
 2. Compute channel variance
 3. Shift and scale result to normalize the mean and variance, and then
    apply a per-channel scale and bias

Previously only step 1 was vectorized. This vectorizes steps 2 and 3 as well.

In tests with the wav2vec example on x64, this made InstanceNormalization ~3x
faster (~26ms -> ~8ms per run).

BatchNormalization and InstanceNormalization are very similar, except
BatchNormalization uses pre-computed mean and variance statistics while
InstanceNormalization computes them dynamically.

Extract the vectorized normalization step from `instance_normalization` and
re-use it for `batch_normalization`.
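The relationship between the two operators can be sketched in scalar form as one shared normalization step with two thin wrappers. All function names here (`normalize_channel`, `batch_norm_1ch`, `instance_norm_1ch`) are hypothetical and the signatures are simplified to a single channel; the real `instance_normalization` and `batch_normalization` operate on tensors and use the vectorized kernels.

```rust
/// Shared step: normalize with the given statistics, then apply scale and bias.
fn normalize_channel(
    xs: &mut [f32],
    mean: f32,
    variance: f32,
    epsilon: f32,
    scale: f32,
    bias: f32,
) {
    let inv_std = 1.0 / (variance + epsilon).sqrt();
    for x in xs.iter_mut() {
        *x = (*x - mean) * inv_std * scale + bias;
    }
}

/// BatchNormalization: the mean and variance are precomputed inputs.
fn batch_norm_1ch(xs: &mut [f32], mean: f32, variance: f32, epsilon: f32, scale: f32, bias: f32) {
    normalize_channel(xs, mean, variance, epsilon, scale, bias);
}

/// InstanceNormalization: the mean and variance are computed from the data.
fn instance_norm_1ch(xs: &mut [f32], epsilon: f32, scale: f32, bias: f32) {
    if xs.is_empty() {
        return;
    }
    let n = xs.len() as f32;
    let mean = xs.iter().sum::<f32>() / n;
    let variance = xs.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / n;
    normalize_channel(xs, mean, variance, epsilon, scale, bias);
}
```

In this sketch, calling `batch_norm_1ch` with statistics that match the data produces the same output as `instance_norm_1ch`, which is why a single kernel can serve both operators.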
@robertknight robertknight changed the title from "Vectorize variance calculation and input scaling in InstanceNormalization" to "Vectorize InstanceNormalization and BatchNormalization" Dec 19, 2024
@robertknight robertknight merged commit 1fd414f into main Dec 19, 2024
2 checks passed
@robertknight robertknight deleted the vec-instance-norm branch December 19, 2024 09:01