Open the file src/input_opt.py. The file ./data/weights.pkl contains network weights pre-trained on MNIST. Turn the network optimization problem around and find an input that makes a particular output neuron extremely happy; in other words, maximize that neuron's activation with respect to the network input.
Use jax.value_and_grad to compute the gradients with respect to the network input. Initialize a network input of shape [1, 28, 28, 1] with jax.random.uniform and iteratively optimize it.
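A minimal sketch of this optimization loop, assuming a stand-in linear network in place of the pre-trained weights (in the exercise, the forward pass and the parameters loaded from ./data/weights.pkl take its place; the function and variable names here are illustrative, not from the starter code):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in network: one linear layer mapping a flattened
# 28x28 image to 10 class scores. In the exercise, `apply` would be
# your MNIST forward pass using the weights from ./data/weights.pkl.
w = jax.random.normal(jax.random.PRNGKey(0), (28 * 28, 10)) * 0.01

def apply(params, x):
    return x.reshape(x.shape[0], -1) @ params

def neuron_activation(x, params, neuron=3):
    # Scalar objective: the activation of one chosen output neuron.
    return apply(params, x)[0, neuron]

# Random starting image of shape [1, 28, 28, 1].
x = jax.random.uniform(jax.random.PRNGKey(1), (1, 28, 28, 1))

# value_and_grad returns the activation and its gradient with respect
# to the input (the first argument) in a single call.
step_size = 0.1
for _ in range(100):
    value, grad = jax.value_and_grad(neuron_activation)(x, w)
    x = x + step_size * grad  # gradient *ascent*: move uphill

print(x.shape)  # the optimized input keeps its [1, 28, 28, 1] shape
```

Note the sign: since we maximize the activation, the update adds the gradient instead of subtracting it as in ordinary training.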
Reuse your MNIST digit recognition code. Implement integrated gradients (IG) as discussed in the lecture. Recall the equation

$$\mathrm{IG}_i(x) = (x_i - x'_i) \cdot \frac{1}{m} \sum_{k=1}^{m} \frac{\partial F\left(x' + \frac{k}{m}(x - x')\right)}{\partial x_i}$$

where $\frac{\partial F}{\partial x_i}$ denotes the gradient of the network output $F$ with respect to input color channel $i$, $x'$ denotes a baseline black image, and $x$ symbolizes an input we are interested in. Finally, $m$ denotes the number of summation steps from the black baseline image to the interesting input.
Follow the TODOs in ./src/mnist_integrated.
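The summation above can be sketched as follows, again with a hypothetical linear stand-in for the trained MNIST model $F$ (the names `F`, `integrated_gradients`, and the class index are illustrative assumptions, not part of the starter code):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the trained network F: a fixed linear
# layer over a flattened 28x28 image, returning one class score.
w = jax.random.normal(jax.random.PRNGKey(0), (28 * 28, 10)) * 0.01

def F(x, class_idx=3):
    return (x.reshape(-1) @ w)[class_idx]

def integrated_gradients(x, m=50):
    # x' is the all-black baseline image.
    baseline = jnp.zeros_like(x)
    grad_fn = jax.grad(F)
    # Riemann sum of the input gradients along the straight path
    # from the baseline to x, taken in m steps.
    total = jnp.zeros_like(x)
    for k in range(1, m + 1):
        point = baseline + (k / m) * (x - baseline)
        total = total + grad_fn(point)
    # Scale the averaged gradients by the input difference (x - x').
    return (x - baseline) * total / m

x = jax.random.uniform(jax.random.PRNGKey(1), (28, 28, 1))
attributions = integrated_gradients(x)
print(attributions.shape)  # one attribution per input pixel
```

A useful sanity check is the completeness property: the attributions should sum to approximately $F(x) - F(x')$, which holds exactly here because the stand-in model is linear.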