Monocular depth estimation - Keras 3 Migration (Only Tensorflow Backend)
This PR migrates the Monocular depth estimation example to Keras 3 (TF-only example), as requested in [KerasCV-Fixit](keras-team/keras-cv#2211).

Please find the verification [gist](https://colab.sandbox.google.com/gist/chunduriv/994b36a97985e44d9573436b987993f9/tf_depth_estimation.ipynb).
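
At a glance, the migration swaps `tf.keras` entry points for their Keras 3 equivalents and moves tensor math onto the backend-agnostic `keras.ops` namespace. Below is a minimal sketch of the substitution pattern (illustrative only, assuming Keras 3 is installed; not part of the diff):

```python
import keras
from keras import layers, ops

# One call now seeds Python, NumPy, and the backend RNGs,
# replacing tf.random.set_seed(123).
keras.utils.set_random_seed(123)

# LeakyReLU's `alpha` argument is named `negative_slope` in Keras 3.
relu = layers.LeakyReLU(negative_slope=0.2)

# Backend-agnostic reductions replace their tf.* counterparts:
# tf.reduce_mean -> ops.mean, tf.abs -> ops.abs, tf.exp -> ops.exp
weight = ops.exp(ops.mean(ops.abs(ops.arange(-3.0, 3.0))))
```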
chunduriv authored Aug 13, 2024
1 parent 2ac94c4 commit 1bf5659
Showing 1 changed file with 28 additions and 26 deletions.
54 changes: 28 additions & 26 deletions examples/vision/depth_estimation.py
@@ -2,7 +2,7 @@
 Title: Monocular depth estimation
 Author: [Victor Basu](https://www.linkedin.com/in/victor-basu-520958147)
 Date created: 2021/08/30
-Last modified: 2021/08/30
+Last modified: 2024/08/13
 Description: Implement a depth estimation model with a convnet.
 Accelerator: GPU
 """
@@ -27,15 +27,17 @@
 import os
 import sys
 
-import tensorflow as tf
-from tensorflow.keras import layers
+import keras
+from keras import layers
+from keras import ops
+import tensorflow as tf
 
 import pandas as pd
 import numpy as np
 import cv2
 import matplotlib.pyplot as plt
 
-tf.random.set_seed(123)
+keras.utils.set_random_seed(123)
 
 """
 ## Downloading the dataset
@@ -52,7 +54,7 @@
 
 annotation_folder = "/dataset/"
 if not os.path.exists(os.path.abspath(".") + annotation_folder):
-    annotation_zip = tf.keras.utils.get_file(
+    annotation_zip = keras.utils.get_file(
         "val.tar.gz",
         cache_subdir=os.path.abspath("."),
         origin="http://diode-dataset.s3.amazonaws.com/val.tar.gz",
@@ -105,7 +107,7 @@
 """
 
 
-class DataGenerator(tf.keras.utils.Sequence):
+class DataGenerator(keras.utils.Sequence):
     def __init__(self, data, batch_size=6, dim=(768, 1024), n_channels=3, shuffle=True):
         """
         Initialization
@@ -178,7 +180,7 @@ def data_generation(self, batch):
             self.data["depth"][batch_id],
             self.data["mask"][batch_id],
         )
-
+        x, y = x.astype("float32"), y.astype("float32")
         return x, y
 
 
@@ -249,10 +251,10 @@ def __init__(
         super().__init__(**kwargs)
         self.convA = layers.Conv2D(filters, kernel_size, strides, padding)
         self.convB = layers.Conv2D(filters, kernel_size, strides, padding)
-        self.reluA = layers.LeakyReLU(alpha=0.2)
-        self.reluB = layers.LeakyReLU(alpha=0.2)
-        self.bn2a = tf.keras.layers.BatchNormalization()
-        self.bn2b = tf.keras.layers.BatchNormalization()
+        self.reluA = layers.LeakyReLU(negative_slope=0.2)
+        self.reluB = layers.LeakyReLU(negative_slope=0.2)
+        self.bn2a = layers.BatchNormalization()
+        self.bn2b = layers.BatchNormalization()
 
         self.pool = layers.MaxPool2D((2, 2), (2, 2))
 
@@ -278,11 +280,11 @@ def __init__(
         self.us = layers.UpSampling2D((2, 2))
         self.convA = layers.Conv2D(filters, kernel_size, strides, padding)
         self.convB = layers.Conv2D(filters, kernel_size, strides, padding)
-        self.reluA = layers.LeakyReLU(alpha=0.2)
-        self.reluB = layers.LeakyReLU(alpha=0.2)
-        self.bn2a = tf.keras.layers.BatchNormalization()
-        self.bn2b = tf.keras.layers.BatchNormalization()
-        self.conc = layers.Concatenate()
+        self.reluA = layers.LeakyReLU(negative_slope=0.2)
+        self.reluB = layers.LeakyReLU(negative_slope=0.2)
+        self.bn2a = layers.BatchNormalization()
+        self.bn2b = layers.BatchNormalization()
+        self.conc = layers.Concatenate()
 
     def call(self, x, skip):
         x = self.us(x)
@@ -305,8 +307,8 @@ def __init__(
         super().__init__(**kwargs)
         self.convA = layers.Conv2D(filters, kernel_size, strides, padding)
         self.convB = layers.Conv2D(filters, kernel_size, strides, padding)
-        self.reluA = layers.LeakyReLU(alpha=0.2)
-        self.reluB = layers.LeakyReLU(alpha=0.2)
+        self.reluA = layers.LeakyReLU(negative_slope=0.2)
+        self.reluB = layers.LeakyReLU(negative_slope=0.2)
 
     def call(self, x):
         x = self.convA(x)
@@ -328,13 +330,13 @@ def call(self, x):
 """
 
 
-class DepthEstimationModel(tf.keras.Model):
+class DepthEstimationModel(keras.Model):
     def __init__(self):
         super().__init__()
        self.ssim_loss_weight = 0.85
         self.l1_loss_weight = 0.1
         self.edge_loss_weight = 0.9
-        self.loss_metric = tf.keras.metrics.Mean(name="loss")
+        self.loss_metric = keras.metrics.Mean(name="loss")
         f = [16, 32, 64, 128, 256]
         self.downscale_blocks = [
             DownscaleBlock(f[0]),
@@ -355,26 +357,26 @@ def calculate_loss(self, target, pred):
         # Edges
         dy_true, dx_true = tf.image.image_gradients(target)
         dy_pred, dx_pred = tf.image.image_gradients(pred)
-        weights_x = tf.exp(tf.reduce_mean(tf.abs(dx_true)))
-        weights_y = tf.exp(tf.reduce_mean(tf.abs(dy_true)))
+        weights_x = ops.cast(ops.exp(ops.mean(ops.abs(dx_true))), "float32")
+        weights_y = ops.cast(ops.exp(ops.mean(ops.abs(dy_true))), "float32")
 
         # Depth smoothness
         smoothness_x = dx_pred * weights_x
         smoothness_y = dy_pred * weights_y
 
-        depth_smoothness_loss = tf.reduce_mean(abs(smoothness_x)) + tf.reduce_mean(
+        depth_smoothness_loss = ops.mean(abs(smoothness_x)) + ops.mean(
             abs(smoothness_y)
         )
 
         # Structural similarity (SSIM) index
-        ssim_loss = tf.reduce_mean(
+        ssim_loss = ops.mean(
             1
             - tf.image.ssim(
                 target, pred, max_val=WIDTH, filter_size=7, k1=0.01**2, k2=0.03**2
             )
         )
         # Point-wise depth
-        l1_loss = tf.reduce_mean(tf.abs(target - pred))
+        l1_loss = ops.mean(ops.abs(target - pred))
 
         loss = (
             (self.ssim_loss_weight * ssim_loss)
@@ -432,7 +434,7 @@ def call(self, x):
 ## Model training
 """
 
-optimizer = tf.keras.optimizers.Adam(
+optimizer = keras.optimizers.Adam(
     learning_rate=LR,
     amsgrad=False,
 )
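
Note that `tf.image.image_gradients` and `tf.image.ssim` are kept as-is, since `keras.ops` offers no direct equivalents; this is why the example remains TF-only. A quick sanity check, as a sketch assuming Keras 3 on the TensorFlow backend, that the swapped-in reductions agree numerically with the originals:

```python
import numpy as np
import tensorflow as tf
from keras import ops

# Random stand-ins for a depth target and prediction (hypothetical shapes).
target = np.random.rand(4, 8, 8, 1).astype("float32")
pred = np.random.rand(4, 8, 8, 1).astype("float32")

l1_old = tf.reduce_mean(tf.abs(target - pred))  # pre-migration
l1_new = ops.mean(ops.abs(target - pred))       # post-migration
np.testing.assert_allclose(float(l1_old), float(l1_new), rtol=1e-6)
```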
