Commit

code released

JackLee396 committed Mar 30, 2022
1 parent 0a12e3e commit 924d9f2
Showing 18 changed files with 1,611 additions and 0 deletions.
20 changes: 20 additions & 0 deletions cfgs/train.yaml
@@ -0,0 +1,20 @@
seed: 0
mixed_precision: false
base_lr: 4.0e-4

nr_gpus: 8
batch_size_single: 2
n_total_epoch: 600
minibatch_per_epoch: 500

loadmodel: ~
log_dir: "./train_log"
model_save_freq_epoch: 1

max_disp: 256
image_width: 512
image_height: 384
training_data_path: "./stereo_trainset/crestereo"

log_level: "logging.INFO"
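
For orientation, here is a minimal sketch of how a training script might consume this config (the loading code below is an assumption for illustration; the commit's entry point is not shown in this diff). Note that "~" in YAML parses to null, so loadmodel arrives as Python None.

import yaml  # assumption: PyYAML is available

with open("cfgs/train.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg["loadmodel"] is None  # "~" is YAML null
effective_batch = cfg["nr_gpus"] * cfg["batch_size_single"]  # 8 * 2 = 16
print(cfg["base_lr"], effective_batch, cfg["n_total_epoch"])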

215 changes: 215 additions & 0 deletions dataset.py
@@ -0,0 +1,215 @@
import os
import cv2
import glob
import numpy as np
from PIL import Image, ImageEnhance

from megengine.data.dataset import Dataset


class Augmentor:
def __init__(
self,
image_height=384,
image_width=512,
max_disp=256,
scale_min=0.6,
scale_max=1.0,
seed=0,
):
super().__init__()
self.image_height = image_height
self.image_width = image_width
self.max_disp = max_disp
self.scale_min = scale_min
self.scale_max = scale_max
self.rng = np.random.RandomState(seed)

def chromatic_augmentation(self, img):
        # note: these draws use the global NumPy RNG, not the seeded self.rng
        random_brightness = np.random.uniform(0.8, 1.2)
        random_contrast = np.random.uniform(0.8, 1.2)
        random_gamma = np.random.uniform(0.8, 1.2)

img = Image.fromarray(img)

enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(random_brightness)
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(random_contrast)

gamma_map = [
255 * 1.0 * pow(ele / 255.0, random_gamma) for ele in range(256)
] * 3
img = img.point(gamma_map) # use PIL's point-function to accelerate this part

img_ = np.array(img)

return img_

def __call__(self, left_img, right_img, left_disp):
# 1. chromatic augmentation
left_img = self.chromatic_augmentation(left_img)
right_img = self.chromatic_augmentation(right_img)

# 2. spatial augmentation
# 2.1) rotate & vertical shift for right image
if self.rng.binomial(1, 0.5):
angle, pixel = 0.1, 2
px = self.rng.uniform(-pixel, pixel)
ag = self.rng.uniform(-angle, angle)
image_center = (
self.rng.uniform(0, right_img.shape[0]),
self.rng.uniform(0, right_img.shape[1]),
)
rot_mat = cv2.getRotationMatrix2D(image_center, ag, 1.0)
right_img = cv2.warpAffine(
right_img, rot_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
)
trans_mat = np.float32([[1, 0, 0], [0, 1, px]])
right_img = cv2.warpAffine(
right_img, trans_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
)

# 2.2) random resize
resize_scale = self.rng.uniform(self.scale_min, self.scale_max)

left_img = cv2.resize(
left_img,
None,
fx=resize_scale,
fy=resize_scale,
interpolation=cv2.INTER_LINEAR,
)
right_img = cv2.resize(
right_img,
None,
fx=resize_scale,
fy=resize_scale,
interpolation=cv2.INTER_LINEAR,
)

disp_mask = (left_disp < float(self.max_disp / resize_scale)) & (left_disp > 0)
disp_mask = disp_mask.astype("float32")
disp_mask = cv2.resize(
disp_mask,
None,
fx=resize_scale,
fy=resize_scale,
interpolation=cv2.INTER_LINEAR,
)

left_disp = (
cv2.resize(
left_disp,
None,
fx=resize_scale,
fy=resize_scale,
interpolation=cv2.INTER_LINEAR,
)
* resize_scale
)

# 2.3) random crop
h, w, c = left_img.shape
dx = w - self.image_width
dy = h - self.image_height
dy = self.rng.randint(min(0, dy), max(0, dy) + 1)
dx = self.rng.randint(min(0, dx), max(0, dx) + 1)

M = np.float32([[1.0, 0.0, -dx], [0.0, 1.0, -dy]])
left_img = cv2.warpAffine(
left_img,
M,
(self.image_width, self.image_height),
flags=cv2.INTER_LINEAR,
borderValue=0,
)
right_img = cv2.warpAffine(
right_img,
M,
(self.image_width, self.image_height),
flags=cv2.INTER_LINEAR,
borderValue=0,
)
left_disp = cv2.warpAffine(
left_disp,
M,
(self.image_width, self.image_height),
flags=cv2.INTER_LINEAR,
borderValue=0,
)
disp_mask = cv2.warpAffine(
disp_mask,
M,
(self.image_width, self.image_height),
flags=cv2.INTER_LINEAR,
borderValue=0,
)

# 3. add random occlusion to right image
if self.rng.binomial(1, 0.5):
sx = int(self.rng.uniform(50, 100))
sy = int(self.rng.uniform(50, 100))
cx = int(self.rng.uniform(sx, right_img.shape[0] - sx))
cy = int(self.rng.uniform(sy, right_img.shape[1] - sy))
right_img[cx - sx : cx + sx, cy - sy : cy + sy] = np.mean(
np.mean(right_img, 0), 0
)[np.newaxis, np.newaxis]

return left_img, right_img, left_disp, disp_mask


class CREStereoDataset(Dataset):
def __init__(self, root):
super().__init__()
self.imgs = glob.glob(os.path.join(root, "**/*_left.jpg"), recursive=True)
self.augmentor = Augmentor(
image_height=384,
image_width=512,
max_disp=256,
scale_min=0.6,
scale_max=1.0,
seed=0,
)
self.rng = np.random.RandomState(0)

def get_disp(self, path):
        disp = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        return disp.astype(np.float32) / 32  # disparities are stored in 1/32-pixel units

def __getitem__(self, index):
# find path
left_path = self.imgs[index]
prefix = left_path[: left_path.rfind("_")]
right_path = prefix + "_right.jpg"
left_disp_path = prefix + "_left.disp.png"
right_disp_path = prefix + "_right.disp.png"

# read img, disp
left_img = cv2.imread(left_path, cv2.IMREAD_COLOR)
right_img = cv2.imread(right_path, cv2.IMREAD_COLOR)
left_disp = self.get_disp(left_disp_path)
right_disp = self.get_disp(right_disp_path)

if self.rng.binomial(1, 0.5):
left_img, right_img = np.fliplr(right_img), np.fliplr(left_img)
left_disp, right_disp = np.fliplr(right_disp), np.fliplr(left_disp)
left_disp[left_disp == np.inf] = 0

        # augmentation
left_img, right_img, left_disp, disp_mask = self.augmentor(
left_img, right_img, left_disp
)

left_img = left_img.transpose(2, 0, 1).astype("uint8")
right_img = right_img.transpose(2, 0, 1).astype("uint8")

return {
"left": left_img,
"right": right_img,
"disparity": left_disp,
"mask": disp_mask,
}

def __len__(self):
return len(self.imgs)
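
A minimal usage sketch for the dataset above, assuming MegEngine's standard data pipeline (the sampler/loader wiring is illustrative and not taken from this commit):

from megengine.data import DataLoader, RandomSampler

dataset = CREStereoDataset(root="./stereo_trainset/crestereo")
sampler = RandomSampler(dataset, batch_size=2, drop_last=True)
loader = DataLoader(dataset, sampler=sampler)

for batch in loader:
    # images are uint8 in CHW order; disparity and mask are float32 HW
    left, right = batch["left"], batch["right"]
    disp, mask = batch["disparity"], batch["mask"]
    break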
Binary file added img/test/left.png
Binary file added img/test/right.png
1 change: 1 addition & 0 deletions nets/__init__.py
@@ -0,0 +1 @@
from .crestereo import CREStereo as Model
2 changes: 2 additions & 0 deletions nets/attention/__init__.py
@@ -0,0 +1,2 @@
from .transformer import LocalFeatureTransformer
from .position_encoding import PositionEncodingSine
96 changes: 96 additions & 0 deletions nets/attention/linear_attention.py
@@ -0,0 +1,96 @@
"""
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
"""

import numpy as np
import megengine.module as M
import megengine.functional as F


def elu(x, alpha=1.0):
return F.maximum(0, x) + F.minimum(0, alpha * (F.exp(x) - 1))


def elu_feature_map(x):
return elu(x) + 1


class LinearAttention(M.Module):
def __init__(self, eps=1e-6):
super().__init__()
self.feature_map = elu_feature_map
self.eps = eps

def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
"""Multi-Head linear attention proposed in "Transformers are RNNs"
Args:
queries: [N, L, H, D]
keys: [N, S, H, D]
values: [N, S, H, D]
q_mask: [N, L]
kv_mask: [N, S]
Returns:
queried_values: (N, L, H, D)
"""
Q = self.feature_map(queries)
K = self.feature_map(keys)

# set padded position to zero
if q_mask is not None:
Q = Q * F.expand_dims(q_mask, (2, 3)) # [:, :, None, None]
if kv_mask is not None:
K = K * F.expand_dims(kv_mask, (2, 3)) # [:, :, None, None]
values = values * F.expand_dims(kv_mask, (2, 3)) # [:, :, None, None]

v_length = values.shape[1]
values = values / v_length # prevent fp16 overflow
KV = F.sum(F.expand_dims(K, -1) * F.expand_dims(values, 3), axis=1)
Z = 1 / (F.sum(Q * F.sum(K, axis=1, keepdims=True), axis=-1) + self.eps)
queried_values = (
F.sum(
F.expand_dims(Q, -1) * F.expand_dims(KV, 1) * F.expand_dims(Z, (3, 4)),
axis=3,
)
* v_length
)

return queried_values


class FullAttention(M.Module):
def __init__(self, use_dropout=False, attention_dropout=0.1):
super().__init__()
self.use_dropout = use_dropout
self.dropout = M.Dropout(drop_prob=attention_dropout)

def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
"""Multi-head scaled dot-product attention, a.k.a full attention.
Args:
queries: [N, L, H, D]
keys: [N, S, H, D]
values: [N, S, H, D]
q_mask: [N, L]
kv_mask: [N, S]
Returns:
queried_values: (N, L, H, D)
"""

# Compute the unnormalized attention and apply the masks
QK = F.sum(F.expand_dims(queries, 2) * F.expand_dims(keys, 1), axis=-1)
if kv_mask is not None:
assert q_mask.dtype == np.bool_
assert kv_mask.dtype == np.bool_
QK[
~(F.expand_dims(q_mask, (2, 3)) & F.expand_dims(kv_mask, (1, 3)))
] = float("-inf")

# Compute the attention and the weighted average
        softmax_temp = 1.0 / queries.shape[3] ** 0.5  # 1 / sqrt(D)
A = F.softmax(softmax_temp * QK, axis=2)
if self.use_dropout:
A = self.dropout(A)

queried_values = F.sum(F.expand_dims(A, -1) * F.expand_dims(values, 1), axis=2)

return queried_values
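
The point of LinearAttention above is associativity: with the positive feature map phi(x) = elu(x) + 1, the softmax attention matrix is replaced by phi(Q) (phi(K)^T V) with a row-wise normalizer, so the cost grows linearly in sequence length instead of quadratically. A small NumPy check of that reordering (illustrative only, not part of the commit):

import numpy as np

def phi(x):  # elu(x) + 1, the positive feature map used above
    return np.where(x > 0, x + 1, np.exp(x))

N, L, S, H, D = 1, 8, 10, 4, 16
Q, K = phi(np.random.randn(N, L, H, D)), phi(np.random.randn(N, S, H, D))
V = np.random.randn(N, S, H, D)

# naive O(L*S): build the L x S attention matrix, normalize, apply to V
A = np.einsum("nlhd,nshd->nlhs", Q, K)
naive = np.einsum("nlhs,nshd->nlhd", A / A.sum(-1, keepdims=True), V)

# linear O(L+S): contract K with V first, as LinearAttention.forward does
KV = np.einsum("nshd,nshv->nhdv", K, V)
Z = 1.0 / np.einsum("nlhd,nhd->nlh", Q, K.sum(axis=1))
fast = np.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z)

assert np.allclose(naive, fast)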
37 changes: 37 additions & 0 deletions nets/attention/position_encoding.py
@@ -0,0 +1,37 @@
import math
import megengine.module as M
import megengine.functional as F


class PositionEncodingSine(M.Module):
"""
    This is a sinusoidal position encoding generalized to 2-dimensional images
"""

def __init__(self, d_model, max_shape=(256, 256)):
"""
Args:
max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
"""
super().__init__()

pe = F.zeros((d_model, *max_shape))
y_position = F.expand_dims(F.cumsum(F.ones(max_shape), 0), 0)
x_position = F.expand_dims(F.cumsum(F.ones(max_shape), 1), 0)
        div_term = F.exp(
            F.arange(0, d_model // 2, 2) * (-math.log(10000.0) / (d_model // 2))
        )
div_term = F.expand_dims(div_term, (1, 2)) # [C//4, 1, 1]
pe[0::4, :, :] = F.sin(x_position * div_term)
pe[1::4, :, :] = F.cos(x_position * div_term)
pe[2::4, :, :] = F.sin(y_position * div_term)
pe[3::4, :, :] = F.cos(y_position * div_term)

self.pe = F.expand_dims(pe, 0)

def forward(self, x):
"""
Args:
x: [N, C, H, W]
"""
return x + self.pe[:, :, : x.shape[2], : x.shape[3]].to(x.device)
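
A short usage sketch for the module above (shapes are illustrative; d_model should be divisible by 4 so the interleaved sin/cos channel slices line up):

import megengine.functional as F
from nets.attention import PositionEncodingSine

pos_encoding = PositionEncodingSine(d_model=256, max_shape=(256, 256))
feat = F.zeros((2, 256, 48, 64))  # e.g. a 1/8-resolution feature map
feat = pos_encoding(feat)         # same shape, sine/cosine encoding added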
(The diffs for the remaining changed files in this commit are not rendered here.)