|
from ultralytics import YOLO
import cv2
import gradio as gr
import numpy as np
import spaces
import os
import torch
import tempfile
import utils
import plotly.graph_objects as go
from io import BytesIO
from PIL import Image
import base64
import sys
import csv

csv.field_size_limit(1048576)

from image_segmenter import ImageSegmenter
from monocular_depth_estimator import MonocularDepthEstimator
from point_cloud_generator import display_pcd

device = torch.device("cpu")

img_seg = ImageSegmenter(model_type="yolov8s-seg")
depth_estimator = MonocularDepthEstimator(model_type="midas_v21_small_256")

def initialize_gpu():
    """Ensure ZeroGPU assigns a GPU before initializing CUDA."""
    global device
    try:
        with spaces.GPU():
            torch.cuda.empty_cache()
            if torch.cuda.is_available():
                device = torch.device("cuda")
                print(f"✅ GPU initialized: {torch.cuda.get_device_name(0)}")
            else:
                print("❌ No GPU detected after ZeroGPU allocation.")
                device = torch.device("cpu")
    except Exception as e:
        print(f"🚨 GPU initialization failed: {e}")
        device = torch.device("cpu")

initialize_gpu()

CANCEL_PROCESSING = False

@spaces.GPU
def process_image(image):
    """Segment the image, estimate depth, draw per-object distances, and build a point-cloud plot."""
    image = utils.resize(image)
    image_segmentation, objects_data = img_seg.predict(image)
    depthmap, depth_colormap = depth_estimator.make_prediction(image)
    dist_image = utils.draw_depth_info(image, depthmap, objects_data)
    objs_pcd = utils.generate_obj_pcd(depthmap, objects_data)
    plot_fig = display_pcd(objs_pcd)
    return image_segmentation, depth_colormap, dist_image, plot_fig
|
|
@spaces.GPU
def test_process_img(image):
    """Return raw intermediate results (segmentation image, object data, depth map, colormap) for testing."""
    image = utils.resize(image)
    image_segmentation, objects_data = img_seg.predict(image)
    depthmap, depth_colormap = depth_estimator.make_prediction(image)
    return image_segmentation, objects_data, depthmap, depth_colormap
|
|
@spaces.GPU
def process_video(vid_path=None):
    """Stream per-frame segmentation, depth, and distance results for a video."""
    global CANCEL_PROCESSING
    CANCEL_PROCESSING = False
    vid_cap = cv2.VideoCapture(vid_path)
    while vid_cap.isOpened():
        ret, frame = vid_cap.read()
        # Stop on a failed read (end of stream) or when the Cancel button sets the flag;
        # the original loop never broke, so it spun forever once frames ran out.
        if not ret or CANCEL_PROCESSING:
            break
        print("making predictions ....")
        frame = utils.resize(frame)
        image_segmentation, objects_data = img_seg.predict(frame)
        depthmap, depth_colormap = depth_estimator.make_prediction(frame)
        dist_image = utils.draw_depth_info(frame, depthmap, objects_data)
        yield cv2.cvtColor(image_segmentation, cv2.COLOR_BGR2RGB), depth_colormap, cv2.cvtColor(dist_image, cv2.COLOR_BGR2RGB)
    vid_cap.release()
    return None
|
|
def update_segmentation_options(options):
    img_seg.is_show_bounding_boxes = 'Show Boundary Box' in options
    img_seg.is_show_segmentation = 'Show Segmentation Region' in options
    img_seg.is_show_segmentation_boundary = 'Show Segmentation Boundary' in options
|
|
def update_confidence_threshold(thres_val, img_seg_instance):
    """Update the confidence threshold on an ImageSegmenter, normalizing percent values to [0, 1]."""
    if thres_val > 1.0:
        thres_val = thres_val / 100.0
    img_seg_instance.confidence_threshold = thres_val
    print(f"Confidence threshold updated to: {thres_val}")
|
|
@spaces.GPU
def model_selector(model_type, img_seg_instance, depth_estimator_instance):
    """Swap the YOLO and MiDaS weights to match the selected speed/accuracy trade-off."""
    if "Small - Better performance and less accuracy" == model_type:
        midas_model, yolo_model = "midas_v21_small_256", "yolov8s-seg"
    elif "Medium - Balanced performance and accuracy" == model_type:
        midas_model, yolo_model = "dpt_hybrid_384", "yolov8m-seg"
    elif "Large - Slow performance and high accuracy" == model_type:
        midas_model, yolo_model = "dpt_large_384", "yolov8l-seg"
    else:
        midas_model, yolo_model = "midas_v21_small_256", "yolov8s-seg"

    # Re-run __init__ on the existing instances so every reference to them sees the new models.
    img_seg_instance.__init__(model_type=yolo_model)
    depth_estimator_instance.__init__(model_type=midas_model)
    print(f"Model updated: YOLO={yolo_model}, MiDaS={midas_model}")
|
|
def get_box_vertices(bbox):
    """Convert bbox to corner vertices"""
    x1, y1, x2, y2 = bbox
    return [
        [x1, y1],
        [x2, y1],
        [x2, y2],
        [x1, y2]
    ]
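# For example, get_box_vertices([10, 20, 30, 40]) returns
# [[10, 20], [30, 20], [30, 40], [10, 40]].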
|
|
|
def depth_at_center(depth_map, bbox):
    """Return the median depth over a small window (up to 5x5) centered on the bounding box."""
    x1, y1, x2, y2 = bbox
    center_x = int((x1 + x2) / 2)
    center_y = int((y1 + y2) / 2)

    region = depth_map[
        max(0, center_y-2):min(depth_map.shape[0], center_y+3),
        max(0, center_x-2):min(depth_map.shape[1], center_x+3)
    ]
    return np.median(region)
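# Taking the median over the small window (rather than reading a single pixel)
# keeps the result robust to isolated depth-map noise at the box center.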
|
|
|
def get_camera_matrix(depth_estimator):
    """Get camera calibration matrix"""
    return {
        "fx": depth_estimator.fx_depth,
        "fy": depth_estimator.fy_depth,
        "cx": depth_estimator.cx_depth,
        "cy": depth_estimator.cy_depth
    }
|
|
|
def encode_base64_image(image_array):
    """
    Encode a NumPy (OpenCV) image array to a base64-encoded PNG data URL
    like "data:image/png;base64,<...>".
    """
    # base64 and cv2 are already imported at module level; the redundant
    # function-local imports have been removed.
    success, encoded_buffer = cv2.imencode(".png", image_array)
    if not success:
        raise ValueError("Could not encode image to PNG buffer")

    b64_str = base64.b64encode(encoded_buffer).decode("utf-8")
    return "data:image/png;base64," + b64_str
|
|
|
def save_image_to_url(image):
    """Save an OpenCV image to a temporary file and return its URL."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        cv2.imwrite(temp_file.name, image)
        return "/".join(temp_file.name.split("/")[-2:])
|
|
|
def save_plot_to_url(objs_pcd):
    """Save a Plotly 3D scatter plot to a temporary HTML file and return its URL."""
    fig = go.Figure()

    for data, clr in objs_pcd:
        points = np.asarray(data.points)
        point_range = range(0, points.shape[0], 1)

        fig.add_trace(go.Scatter3d(
            x=points[point_range, 0],
            y=points[point_range, 1],
            z=points[point_range, 2]*100,
            mode='markers',
            marker=dict(
                size=1,
                color='rgb'+str(clr),
                opacity=1
            )
        ))

    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z'
        )
    )

    with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as temp_file:
        fig.write_html(temp_file.name)
        return "/".join(temp_file.name.split("/")[-2:])

|
|
def get_3d_position(center, depth, camera_matrix):
    """Project 2D center into 3D space using depth and camera matrix."""
    cx, cy = center
    fx, fy = camera_matrix["fx"], camera_matrix["fy"]
    cx_d, cy_d = camera_matrix["cx"], camera_matrix["cy"]

    x = (cx - cx_d) * depth / fx
    y = (cy - cy_d) * depth / fy
    z = depth

    return [x, y, z]
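# Worked example with hypothetical intrinsics fx = fy = 500, (cx, cy) = (320, 240):
# a pixel at (420, 240) with depth 2.0 maps to x = (420 - 320) * 2.0 / 500 = 0.4,
# y = (240 - 240) * 2.0 / 500 = 0.0, z = 2.0, i.e. [0.4, 0.0, 2.0].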
|
|
|
def get_bbox_from_mask(mask):
    """Get the bounding box (x1, y1, x2, y2) of the largest contour in a binary mask."""
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    biggest_contour = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(biggest_contour)
    return x, y, x+w, y+h
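# For example, a mask that is 1 on rows 5-9 and columns 10-19 gives
# boundingRect x=10, y=5, w=10, h=5, so the function returns (10, 5, 20, 10).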
|
|
|
@spaces.GPU
def get_detection_data(image_data):
    """Decode a base64 image payload, run segmentation and depth estimation, and return detections as JSON."""
    global img_seg, depth_estimator

    try:
        if isinstance(image_data, dict):
            nested_dict = image_data.get("image", {}).get("image", {})
            full_data_url = nested_dict.get("data", "")
            model_size = image_data.get("model_size", "Small - Better performance and less accuracy")
            confidence_threshold = image_data.get("confidence_threshold", 0.1)
            distance_threshold = image_data.get("distance_threshold", 10.0)
        else:
            full_data_url = image_data
            model_size = "Small - Better performance and less accuracy"
            confidence_threshold = 0.6
            distance_threshold = 10.0

        if not full_data_url:
            return {"error": "No base64 data found in input."}

        if full_data_url.startswith("data:image"):
            _, b64_string = full_data_url.split(",", 1)
        else:
            b64_string = full_data_url

        img_data = base64.b64decode(b64_string)
        img = Image.open(BytesIO(img_data))
        img = np.array(img)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        resized_image = utils.resize(img)
        print(f"Debug - Resized image shape: {resized_image.shape}")
        image = img
        print(f"Debug - Original image shape: {image.shape}")

        model_selector(model_size, img_seg, depth_estimator)
        update_confidence_threshold(confidence_threshold, img_seg)

        image_segmentation, objects_data = img_seg.predict(resized_image)
        depthmap, depth_colormap = depth_estimator.make_prediction(resized_image)

        detections = []
        for idx, obj in enumerate(objects_data):
            cls_id, cls_name, center, mask, color_bgr, confidence = obj
            x1, y1, x2, y2 = get_bbox_from_mask(mask)
            print(f"Debug - Object {idx}: Original Center = {center}, Original Vertices = {get_box_vertices([x1, y1, x2, y2])}")

            masked_depth_map, mean_depth = utils.get_masked_depth(depthmap, mask)
            print(f"Debug - Object {idx}: Masked depth min/max: {masked_depth_map.min()}, {masked_depth_map.max()}, Mean depth: {mean_depth}")

            # Check the type first so np.isnan is never called on a non-numeric value.
            if not isinstance(mean_depth, (int, float)) or np.isnan(mean_depth) or mean_depth <= 0:
                print(f"Warning: Invalid mean depth ({mean_depth}) for Object {idx}. Using default depth of 1.0...")
                mean_depth = 1.0

            # Scale the relative depth estimate into the app's distance units.
            real_distance = mean_depth * 10

            color_rgb = (int(color_bgr[2]), int(color_bgr[1]), int(color_bgr[0]))

            if real_distance <= distance_threshold:
                detections.append({
                    "class_id": cls_id,
                    "class_name": cls_name,
                    "bounding_box": {"vertices": get_box_vertices([x1, y1, x2, y2])},
                    "center_2d": center,
                    "distance": float(real_distance),
                    "color": color_rgb,
                    "confidence": float(confidence)
                })
            else:
                print(f"Debug - Object {idx} filtered out: Distance {real_distance} exceeds threshold {distance_threshold}")

        response = {
            "detections": detections,
        }
        print(f"Debug - Response: {response}")
        return response

    except Exception as e:
        print(f"🚨 Error in get_detection_data: {str(e)}")
        return {"error": str(e)}
|
|
def cancel():
    """Request that video processing stop after the current frame."""
    global CANCEL_PROCESSING
    CANCEL_PROCESSING = True

|
|
if __name__ == "__main__":

    with gr.Blocks() as my_app:
        gr.Markdown("<h1><center>Simultaneous Segmentation and Depth Estimation</center></h1>")
        gr.Markdown("<h3><center>Created by Vaishanth</center></h3>")
        gr.Markdown("<h3><center>This model estimates the depth of segmented objects.</center></h3>")

        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    img_input = gr.Image()
                    model_type_img = gr.Dropdown(
                        ["Small - Better performance and less accuracy",
                         "Medium - Balanced performance and accuracy",
                         "Large - Slow performance and high accuracy"],
                        label="Model Type", value="Small - Better performance and less accuracy",
                        info="Select the inference model before running predictions!")
                    options_checkbox_img = gr.CheckboxGroup(["Show Boundary Box", "Show Segmentation Region", "Show Segmentation Boundary"], label="Options")
                    conf_thres_img = gr.Slider(1, 100, value=60, label="Confidence Threshold", info="Choose the threshold above which objects should be detected")
                    submit_btn_img = gr.Button(value="Predict")

                with gr.Column(scale=2):
                    with gr.Row():
                        segmentation_img_output = gr.Image(height=300, label="Segmentation")
                        depth_img_output = gr.Image(height=300, label="Depth Estimation")

                    with gr.Row():
                        dist_img_output = gr.Image(height=300, label="Distance")
                        pcd_img_output = gr.Plot(label="Point Cloud")

            gr.Markdown("## Sample Images")
            gr.Examples(
                examples=[os.path.join(os.path.dirname(__file__), "assets/images/baggage_claim.jpg"),
                          os.path.join(os.path.dirname(__file__), "assets/images/kitchen_2.png"),
                          os.path.join(os.path.dirname(__file__), "assets/images/soccer.jpg"),
                          os.path.join(os.path.dirname(__file__), "assets/images/room_2.png"),
                          os.path.join(os.path.dirname(__file__), "assets/images/living_room.jpg")],
                inputs=img_input,
                outputs=[segmentation_img_output, depth_img_output, dist_img_output, pcd_img_output],
                fn=process_image,
                cache_examples=False,
            )
|
|
|
with gr.Tab("Video"): |
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
vid_input = gr.Video() |
|
model_type_vid = gr.Dropdown( |
|
["Small - Better performance and less accuracy", |
|
"Medium - Balanced performance and accuracy", |
|
"Large - Slow performance and high accuracy"], |
|
label="Model Type", value="Small - Better performance and less accuracy", |
|
info="Select the inference model before running predictions!") |
|
|
|
options_checkbox_vid = gr.CheckboxGroup(["Show Boundary Box", "Show Segmentation Region", "Show Segmentation Boundary"], label="Options") |
|
conf_thres_vid = gr.Slider(1, 100, value=60, label="Confidence Threshold", info="Choose the threshold above which objects should be detected") |
|
with gr.Row(): |
|
cancel_btn = gr.Button(value="Cancel") |
|
submit_btn_vid = gr.Button(value="Predict") |
|
|
|
with gr.Column(scale=2): |
|
with gr.Row(): |
|
segmentation_vid_output = gr.Image(height=300, label="Segmentation") |
|
depth_vid_output = gr.Image(height=300, label="Depth Estimation") |
|
|
|
with gr.Row(): |
|
dist_vid_output = gr.Image(height=300, label="Distance") |
|
|
|
gr.Markdown("## Sample Videos") |
|
gr.Examples( |
|
examples=[os.path.join(os.path.dirname(__file__), "assets/videos/input_video.mp4"), |
|
os.path.join(os.path.dirname(__file__), "assets/videos/driving.mp4"), |
|
os.path.join(os.path.dirname(__file__), "assets/videos/overpass.mp4"), |
|
os.path.join(os.path.dirname(__file__), "assets/videos/walking.mp4")], |
|
inputs=vid_input, |
|
|
|
|
|
) |
|
|
|
|
|
with gr.Tab("API", visible=False): |
|
api_input = gr.JSON() |
|
api_output = gr.JSON() |
|
gr.Interface( |
|
fn=get_detection_data, |
|
inputs=api_input, |
|
outputs=api_output, |
|
api_name="get_detection_data" |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        submit_btn_img.click(process_image, inputs=img_input, outputs=[segmentation_img_output, depth_img_output, dist_img_output, pcd_img_output])
        options_checkbox_img.change(update_segmentation_options, options_checkbox_img, [])
        conf_thres_img.change(lambda x: update_confidence_threshold(x, img_seg), conf_thres_img, [])
        model_type_img.change(lambda x: model_selector(x, img_seg, depth_estimator), model_type_img, [])

        submit_btn_vid.click(process_video, inputs=vid_input, outputs=[segmentation_vid_output, depth_vid_output, dist_vid_output])
        model_type_vid.change(lambda x: model_selector(x, img_seg, depth_estimator), model_type_vid, [])
        cancel_btn.click(cancel, inputs=[], outputs=[])
        options_checkbox_vid.change(update_segmentation_options, options_checkbox_vid, [])
        conf_thres_vid.change(lambda x: update_confidence_threshold(x, img_seg), conf_thres_vid, [])

    my_app.queue(max_size=20).launch(share=True)