-
Notifications
You must be signed in to change notification settings - Fork 0
/
pool_forward.py
60 lines (45 loc) · 2.46 KB
/
pool_forward.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
def pool_forward(A_prev, hparameters, mode="max"):
    """
    Implement the forward pass of a pooling layer.

    Arguments:
    A_prev -- input data, numpy array of shape (m, n_H_prev, n_W_prev, n_C_prev)
    hparameters -- python dictionary containing "f" (window size) and "stride"
    mode -- the pooling mode, defined as a string ("max" or "average")

    Returns:
    A -- output of the pool layer, a numpy array of shape (m, n_H, n_W, n_C)
    cache -- tuple (A_prev, hparameters), kept for the backward pass of the
             pooling layer

    Raises:
    ValueError -- if mode is neither "max" nor "average"
    """
    # Resolve the reduction once, up front. An unknown mode used to fall
    # through both branches and silently return an all-zero output.
    if mode == "max":
        reduce_window = np.max
    elif mode == "average":
        reduce_window = np.mean
    else:
        raise ValueError(f"mode must be 'max' or 'average', got {mode!r}")

    # Retrieve dimensions from the input shape
    (m, n_H_prev, n_W_prev, n_C_prev) = A_prev.shape

    # Retrieve hyperparameters from "hparameters"
    f = hparameters["f"]
    stride = hparameters["stride"]

    # Define the dimensions of the output. Integer floor division avoids the
    # float round-trip; clamping at 0 means a window larger than the input
    # consistently yields an empty output rather than a negative dimension.
    n_H = max((n_H_prev - f) // stride + 1, 0)
    n_W = max((n_W_prev - f) // stride + 1, 0)
    n_C = n_C_prev

    # Initialize output matrix A
    A = np.zeros((m, n_H, n_W, n_C))

    for h in range(n_H):  # loop on the vertical axis of the output volume
        # Vertical start and end of the current window
        vert_start = h * stride
        vert_end = vert_start + f

        for w in range(n_W):  # loop on the horizontal axis of the output volume
            # Horizontal start and end of the current window
            horiz_start = w * stride
            horiz_end = horiz_start + f

            # Take the window for every example and every channel at once,
            # then pool over the two spatial axes only. This replaces the
            # explicit Python loops over examples and channels.
            window = A_prev[:, vert_start:vert_end, horiz_start:horiz_end, :]
            A[:, h, w, :] = reduce_window(window, axis=(1, 2))

    # Store the input and hparameters in "cache" for pool_backward()
    cache = (A_prev, hparameters)

    return A, cache