-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlcn.py
61 lines (47 loc) · 2.03 KB
/
lcn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import torch.nn.functional as F
import numpy as np
import PIL.Image as pim
import matplotlib.pyplot as plt
import cv2
import scipy.misc
def LocalContrastNorm(image, radius=9, sigma=2.0):
    """Apply local contrast normalization to a batch of images.

    A normalized Gaussian window spanning all input channels is used to
    estimate a local mean, which is subtracted from every channel
    (subtractive normalization).  The same window applied to the squared,
    centered image gives a local standard deviation; each pixel is then
    divided by the larger of that deviation and the image's mean deviation,
    floored at 1e-4 (divisive normalization).

    Args:
        image: torch.Tensor of shape (batch, channels, height, width).
            Non-floating-point tensors are converted to float32.
        radius: Side length of the square Gaussian window (int).  An even
            value is bumped to the next odd value so the window has a
            center pixel.
        sigma: Standard deviation of the Gaussian window, in pixels.

    Returns:
        torch.Tensor with the same shape, dtype and device as the (floating
        point) input, holding the contrast-normalized image.

    Raises:
        ValueError: if the effective radius is smaller than 1 or sigma is
            not positive.
    """
    if radius % 2 == 0:
        radius += 1
    if radius < 1:
        raise ValueError("radius must be a positive integer, got %r" % (radius,))
    if sigma <= 0:
        raise ValueError("sigma must be positive, got %r" % (sigma,))

    if not image.is_floating_point():
        image = image.float()

    channels = image.shape[1]
    mid = radius // 2

    # Build the window on the input's dtype/device so the convolution works
    # for float64 and GPU tensors alike.  The Gaussian's 1 / (2*pi*sigma^2)
    # prefactor is omitted because it cancels in the normalization below.
    offsets = torch.arange(radius, dtype=image.dtype, device=image.device) - mid
    profile = torch.exp(-offsets.pow(2) / (2.0 * sigma ** 2))
    window = profile[:, None] * profile[None, :]
    # The weights sum to 1 over *all* channels, so the single-output-channel
    # convolution is a weighted mean across space and channels.
    window = window / (window.sum() * channels)
    gaussian_filter = window[None, None].repeat(1, channels, 1, 1)

    # "Same" zero padding keeps the spatial size; unlike over-padding and
    # cropping with [mid:-mid], it stays valid when mid == 0 (radius == 1).
    local_mean = F.conv2d(image, gaussian_filter, padding=mid)

    ### Subtractive Normalization
    centered_image = image - local_mean

    ## Variance Calc
    local_var = F.conv2d(centered_image.pow(2), gaussian_filter, padding=mid)
    s_deviation = local_var.sqrt()
    # One mean per image, so items in a batch do not influence each other.
    per_img_mean = s_deviation.mean(dim=(1, 2, 3), keepdim=True)

    ## Divisive Normalization
    # Stay in torch (no numpy round trip) so autograd and device are kept.
    divisor = torch.max(per_img_mean, s_deviation).clamp(min=1e-4)
    return centered_image / divisor
if __name__ == "__main__":
    # Demo: normalize 'tulsi.ppm' and write a viewable result to 'outfile.jpg'.
    # Guarded so that importing this module does not trigger file I/O.
    image = plt.imread('tulsi.ppm')

    # (height, width, channels) -> (1, channels, height, width), float32.
    chw = np.ascontiguousarray(np.asarray(image).transpose((2, 0, 1)))
    image_tensor = torch.from_numpy(chw).float().unsqueeze(0)
    print(image_tensor.shape)

    ret = LocalContrastNorm(image_tensor, radius=9)
    ret = ret[0].numpy().transpose((1, 2, 0))

    ## Scaled between 0 to 1 to see properly
    lowest = ret.min()
    span = ret.max() - lowest
    # A constant result has zero span; avoid dividing by zero in that case.
    scaled_ret = (ret - lowest) / span if span > 0 else np.zeros_like(ret)

    # scipy.misc.imsave was removed in SciPy 1.2; matplotlib's imsave accepts
    # a float RGB array in [0, 1] and infers the format from the extension.
    plt.imsave('outfile.jpg', scaled_ret)