-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0a12e3e
commit 924d9f2
Showing
18 changed files
with
1,611 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
seed: 0 | ||
mixed_precision: false | ||
base_lr: 4.0e-4 | ||
|
||
nr_gpus: 8 | ||
batch_size_single: 2 | ||
n_total_epoch: 600 | ||
minibatch_per_epoch: 500 | ||
|
||
loadmodel: ~ | ||
log_dir: "./train_log" | ||
model_save_freq_epoch: 1 | ||
|
||
max_disp: 256 | ||
image_width: 512 | ||
image_height: 384 | ||
training_data_path: "./stereo_trainset/crestereo" | ||
|
||
log_level: "logging.INFO" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
import os | ||
import cv2 | ||
import glob | ||
import numpy as np | ||
from PIL import Image, ImageEnhance | ||
|
||
from megengine.data.dataset import Dataset | ||
|
||
|
||
class Augmentor:
    """Photometric and geometric augmentation for one stereo training sample.

    All randomness is drawn from a single ``numpy.random.RandomState`` seeded
    with ``seed`` so that a given seed reproduces the same augmentation
    sequence.

    Args:
        image_height: height of the cropped outputs, in pixels.
        image_width: width of the cropped outputs, in pixels.
        max_disp: disparities at or above ``max_disp / resize_scale`` (before
            rescaling) are excluded from the validity mask.
        scale_min: lower bound of the random resize factor.
        scale_max: upper bound of the random resize factor.
        seed: seed for the internal random generator.
    """

    def __init__(
        self,
        image_height=384,
        image_width=512,
        max_disp=256,
        scale_min=0.6,
        scale_max=1.0,
        seed=0,
    ):
        super().__init__()
        self.image_height = image_height
        self.image_width = image_width
        self.max_disp = max_disp
        self.scale_min = scale_min
        self.scale_max = scale_max
        self.rng = np.random.RandomState(seed)

    def _sample_chromatic_params(self):
        """Draw (brightness, contrast, gamma) factors, each in [0.8, 1.2)."""
        # Use the seeded generator rather than the global ``np.random`` state
        # so the ``seed`` argument also controls the photometric jitter.
        brightness = self.rng.uniform(0.8, 1.2)
        contrast = self.rng.uniform(0.8, 1.2)
        gamma = self.rng.uniform(0.8, 1.2)
        return brightness, contrast, gamma

    @staticmethod
    def _rescale(array, scale):
        """Resize ``array`` by ``scale`` on both axes with bilinear sampling."""
        return cv2.resize(
            array,
            None,
            fx=scale,
            fy=scale,
            interpolation=cv2.INTER_LINEAR,
        )

    def _crop(self, array, shift_mat):
        """Translate ``array`` by ``shift_mat`` into the output-sized canvas.

        Regions that fall outside the source are filled with zeros.
        """
        return cv2.warpAffine(
            array,
            shift_mat,
            (self.image_width, self.image_height),
            flags=cv2.INTER_LINEAR,
            borderValue=0,
        )

    def chromatic_augmentation(self, img):
        """Apply random brightness, contrast and gamma jitter to ``img``.

        Args:
            img: image array accepted by ``PIL.Image.fromarray``; the gamma
                lookup table is built for three bands.

        Returns:
            The jittered image as a numpy array.
        """
        random_brightness, random_contrast, random_gamma = (
            self._sample_chromatic_params()
        )

        img = Image.fromarray(img)

        img = ImageEnhance.Brightness(img).enhance(random_brightness)
        img = ImageEnhance.Contrast(img).enhance(random_contrast)

        # One 256-entry gamma curve, repeated once per band.
        gamma_map = [
            255 * 1.0 * pow(ele / 255.0, random_gamma) for ele in range(256)
        ] * 3
        img = img.point(gamma_map)  # use PIL's point-function to accelerate this part

        return np.array(img)

    def __call__(self, left_img, right_img, left_disp):
        """Augment a stereo pair together with its left disparity map.

        Args:
            left_img: left view, H x W x C array.
            right_img: right view, H x W x C array.
            left_disp: disparity of the left view, H x W float array.

        Returns:
            Tuple ``(left_img, right_img, left_disp, disp_mask)``; every array
            is cropped/padded to ``(image_height, image_width)``. ``disp_mask``
            is a float32 map derived from ``0 < disp < max_disp / scale`` and
            is bilinearly resampled, so it can hold fractional values.
        """
        # 1. chromatic augmentation (independent factors for each view)
        left_img = self.chromatic_augmentation(left_img)
        right_img = self.chromatic_augmentation(right_img)

        # 2. spatial augmentation
        # 2.1) rotate & vertical shift for right image
        if self.rng.binomial(1, 0.5):
            angle, pixel = 0.1, 2
            px = self.rng.uniform(-pixel, pixel)
            ag = self.rng.uniform(-angle, angle)
            rows, cols = right_img.shape[:2]
            # OpenCV takes the rotation centre as (x, y) = (column, row).
            image_center = (
                self.rng.uniform(0, cols),
                self.rng.uniform(0, rows),
            )
            rot_mat = cv2.getRotationMatrix2D(image_center, ag, 1.0)
            right_img = cv2.warpAffine(
                right_img, rot_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
            )
            trans_mat = np.float32([[1, 0, 0], [0, 1, px]])
            right_img = cv2.warpAffine(
                right_img, trans_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
            )

        # 2.2) random resize
        resize_scale = self.rng.uniform(self.scale_min, self.scale_max)

        left_img = self._rescale(left_img, resize_scale)
        right_img = self._rescale(right_img, resize_scale)

        # The mask is computed on the un-scaled disparity, hence the threshold
        # is divided by the scale that is applied to the disparity below.
        disp_mask = (left_disp < float(self.max_disp / resize_scale)) & (left_disp > 0)
        disp_mask = self._rescale(disp_mask.astype("float32"), resize_scale)

        # Disparity values shrink with the image, so scale them as well.
        left_disp = self._rescale(left_disp, resize_scale) * resize_scale

        # 2.3) random crop
        h, w = left_img.shape[:2]
        dx = w - self.image_width
        dy = h - self.image_height
        # A negative difference means the image is smaller than the target;
        # the offset range then covers the possible zero-padding positions.
        dy = self.rng.randint(min(0, dy), max(0, dy) + 1)
        dx = self.rng.randint(min(0, dx), max(0, dx) + 1)

        M = np.float32([[1.0, 0.0, -dx], [0.0, 1.0, -dy]])
        left_img = self._crop(left_img, M)
        right_img = self._crop(right_img, M)
        left_disp = self._crop(left_disp, M)
        disp_mask = self._crop(disp_mask, M)

        # 3. add random occlusion to right image
        if self.rng.binomial(1, 0.5):
            # sx/cx are the half-extent and centre along axis 0 (rows),
            # sy/cy along axis 1 (columns).
            sx = int(self.rng.uniform(50, 100))
            sy = int(self.rng.uniform(50, 100))
            cx = int(self.rng.uniform(sx, right_img.shape[0] - sx))
            cy = int(self.rng.uniform(sy, right_img.shape[1] - sy))
            # Fill the patch with the per-channel mean colour of the image.
            right_img[cx - sx : cx + sx, cy - sy : cy + sy] = np.mean(
                np.mean(right_img, 0), 0
            )[np.newaxis, np.newaxis]

        return left_img, right_img, left_disp, disp_mask
|
||
|
||
class CREStereoDataset(Dataset):
    """Stereo dataset that discovers ``*_left.jpg`` files under a root folder.

    For every left image ``<prefix>_left.jpg`` the sample also reads
    ``<prefix>_right.jpg``, ``<prefix>_left.disp.png`` and
    ``<prefix>_right.disp.png``.

    Args:
        root: directory searched recursively for ``*_left.jpg`` files.
    """

    def __init__(self, root):
        super().__init__()
        # glob returns files in filesystem order; sort so that a given index
        # always refers to the same sample across machines and runs.
        self.imgs = sorted(
            glob.glob(os.path.join(root, "**/*_left.jpg"), recursive=True)
        )
        self.augmentor = Augmentor(
            image_height=384,
            image_width=512,
            max_disp=256,
            scale_min=0.6,
            scale_max=1.0,
            seed=0,
        )
        self.rng = np.random.RandomState(0)

    @staticmethod
    def _read_image(path):
        """Read a colour image, raising instead of returning ``None``."""
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError("cannot read image: {}".format(path))
        return img

    def get_disp(self, path):
        """Load a disparity map from ``path`` as float32.

        The stored values are divided by 32 (presumably the fixed-point scale
        used when the PNG was written -- confirm against the dataset spec).

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        disp = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if disp is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError("cannot read disparity map: {}".format(path))
        return disp.astype(np.float32) / 32

    def __getitem__(self, index):
        """Return the augmented sample at ``index``.

        Returns:
            dict with keys ``"left"`` and ``"right"`` (uint8, channel-first),
            ``"disparity"`` and ``"mask"`` (as produced by the augmentor).
        """
        # find path
        left_path = self.imgs[index]
        prefix = left_path[: left_path.rfind("_")]
        right_path = prefix + "_right.jpg"
        left_disp_path = prefix + "_left.disp.png"
        right_disp_path = prefix + "_right.disp.png"

        # read img, disp
        left_img = self._read_image(left_path)
        right_img = self._read_image(right_path)
        left_disp = self.get_disp(left_disp_path)
        right_disp = self.get_disp(right_disp_path)

        # With probability 0.5 swap the two views and mirror them
        # horizontally; the mirrored right view then plays the left role.
        if self.rng.binomial(1, 0.5):
            left_img, right_img = np.fliplr(right_img), np.fliplr(left_img)
            left_disp, right_disp = np.fliplr(right_disp), np.fliplr(left_disp)
        left_disp[left_disp == np.inf] = 0

        # augmentation
        left_img, right_img, left_disp, disp_mask = self.augmentor(
            left_img, right_img, left_disp
        )

        # HWC -> CHW
        left_img = left_img.transpose(2, 0, 1).astype("uint8")
        right_img = right_img.transpose(2, 0, 1).astype("uint8")

        return {
            "left": left_img,
            "right": right_img,
            "disparity": left_disp,
            "mask": disp_mask,
        }

    def __len__(self):
        """Number of left images found under the root directory."""
        return len(self.imgs)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .crestereo import CREStereo as Model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .transformer import LocalFeatureTransformer | ||
from .position_encoding import PositionEncodingSine |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
""" | ||
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" | ||
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py | ||
""" | ||
|
||
import numpy as np | ||
import megengine.module as M | ||
import megengine.functional as F | ||
|
||
|
||
def elu(x, alpha=1.0):
    """Exponential linear unit assembled from elementwise max/min.

    Computes ``max(0, x) + min(0, alpha * (exp(x) - 1))``.
    """
    positive_part = F.maximum(0, x)
    negative_part = F.minimum(0, alpha * (F.exp(x) - 1))
    return positive_part + negative_part
|
||
|
||
def elu_feature_map(x):
    """Feature map used by linear attention: ``elu(x) + 1``."""
    activated = elu(x)
    return activated + 1
|
||
|
||
class LinearAttention(M.Module):
    """Linear attention using the ``elu(x) + 1`` feature map."""

    def __init__(self, eps=1e-6):
        super().__init__()
        self.feature_map = elu_feature_map
        # Added to the normaliser's denominator to avoid division by zero.
        self.eps = eps

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """Multi-head linear attention from "Transformers are RNNs".

        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L]
            kv_mask: [N, S]
        Returns:
            queried_values: (N, L, H, D)
        """
        q_feat = self.feature_map(queries)
        k_feat = self.feature_map(keys)

        # Zero out padded positions.
        if q_mask is not None:
            q_feat = q_feat * F.expand_dims(q_mask, (2, 3))
        if kv_mask is not None:
            kv_keep = F.expand_dims(kv_mask, (2, 3))
            k_feat = k_feat * kv_keep
            values = values * kv_keep

        # Divide by the source length here and multiply back at the end to
        # prevent fp16 overflow in the intermediate sums.
        src_len = values.shape[1]
        scaled_values = values / src_len

        # Sum of outer products of key features and values over S.
        kv = F.sum(
            F.expand_dims(k_feat, -1) * F.expand_dims(scaled_values, 3), axis=1
        )

        # Per-query normaliser: 1 / (Q . sum_S(K) + eps).
        key_sum = F.sum(k_feat, axis=1, keepdims=True)
        normalizer = 1 / (F.sum(q_feat * key_sum, axis=-1) + self.eps)

        weighted = (
            F.expand_dims(q_feat, -1)
            * F.expand_dims(kv, 1)
            * F.expand_dims(normalizer, (3, 4))
        )
        return F.sum(weighted, axis=3) * src_len
|
||
|
||
class FullAttention(M.Module):
    """Scaled dot-product attention with optional dropout on the weights."""

    def __init__(self, use_dropout=False, attention_dropout=0.1):
        super().__init__()
        self.use_dropout = use_dropout
        self.dropout = M.Dropout(drop_prob=attention_dropout)

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """Multi-head scaled dot-product attention, a.k.a. full attention.

        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L]
            kv_mask: [N, S]
        Returns:
            queried_values: (N, L, H, D)
        """
        # Unnormalised scores: dot product over D for every (query, key) pair.
        scores = F.sum(F.expand_dims(queries, 2) * F.expand_dims(keys, 1), axis=-1)

        if kv_mask is not None:
            # NOTE(review): q_mask is dereferenced whenever kv_mask is given,
            # so callers must pass both masks together.
            assert q_mask.dtype == np.bool_
            assert kv_mask.dtype == np.bool_
            # NOTE(review): the combined mask has a trailing axis of size 1
            # while the scores have H there -- confirm the boolean index
            # behaves as intended for H > 1.
            valid = F.expand_dims(q_mask, (2, 3)) & F.expand_dims(kv_mask, (1, 3))
            scores[~valid] = float("-inf")

        # Softmax over the key axis with the 1 / sqrt(D) temperature.
        temperature = 1.0 / queries.shape[3] ** 0.5
        attn = F.softmax(temperature * scores, axis=2)
        if self.use_dropout:
            attn = self.dropout(attn)

        # Weighted average of the values over the key axis.
        return F.sum(F.expand_dims(attn, -1) * F.expand_dims(values, 1), axis=2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import math | ||
import megengine.module as M | ||
import megengine.functional as F | ||
|
||
|
||
class PositionEncodingSine(M.Module):
    """
    Sinusoidal position encoding generalized to 2-dimensional images.
    """

    def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=False):
        """
        Args:
            d_model (int): number of channels of the encoded feature map.
            max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
            temp_bug_fix (bool): when False (default) keep the historical
                frequency scale ``-log(10000) / d_model // 2``. Because ``/``
                and ``//`` share precedence and associate left-to-right, that
                expression is ``(-log(10000) / d_model) // 2``, i.e. a floor
                division. The default is kept so that encodings stay
                numerically identical for models built with the legacy
                formula. Set to True to use the intended
                ``-log(10000) / (d_model // 2)``.
        """
        super().__init__()

        pe = F.zeros((d_model, *max_shape))
        # Cumulative sums of ones give 1-based row / column coordinates.
        y_position = F.expand_dims(F.cumsum(F.ones(max_shape), 0), 0)
        x_position = F.expand_dims(F.cumsum(F.ones(max_shape), 1), 0)

        if temp_bug_fix:
            freq_scale = -math.log(10000.0) / (d_model // 2)
        else:
            # Legacy behaviour: the trailing ``// 2`` floors the quotient.
            freq_scale = -math.log(10000.0) / d_model // 2

        div_term = F.exp(F.arange(0, d_model // 2, 2) * freq_scale)
        div_term = F.expand_dims(div_term, (1, 2))  # [C//4, 1, 1]

        # Channels are interleaved as sin(x), cos(x), sin(y), cos(y).
        pe[0::4, :, :] = F.sin(x_position * div_term)
        pe[1::4, :, :] = F.cos(x_position * div_term)
        pe[2::4, :, :] = F.sin(y_position * div_term)
        pe[3::4, :, :] = F.cos(y_position * div_term)

        self.pe = F.expand_dims(pe, 0)

    def forward(self, x):
        """
        Add the encoding, cropped to the spatial size of ``x``.

        Args:
            x: [N, C, H, W]
        """
        return x + self.pe[:, :, : x.shape[2], : x.shape[3]].to(x.device)
Oops, something went wrong.