# Source: Hugging Face Space file "app.py", uploaded by vaishanthr.
# Last commit: 425bfff ("Update app.py"); file size: 6.14 kB.
from ultralytics import YOLO
import cv2
import gradio as gr
import numpy as np
import os
import torch
from image_segmenter import ImageSegmenter
# params
# Module-level flag, initially False; cancel() assigns to a name of the same
# spelling. NOTE(review): presumably meant to signal that in-progress
# processing should stop - confirm where it is supposed to be read.
CANCEL_PROCESSING = False
# Shared segmenter instance used by the prediction and settings callbacks in
# this module; it starts out with the "yolov8m-seg-custom" model type.
img_seg = ImageSegmenter(model_type="yolov8m-seg-custom")
def resize(image):
    """
    scale the input nd array to a fixed size picked from its orientation:
    taller-than-wide input is resized with the size tuple (480, 640),
    everything else (square input included) with (640, 480)
    """
    height, width = image.shape[:2]
    # the size tuple is handed to cv2.resize unchanged as its second argument
    target_size = (480, 640) if height > width else (640, 480)
    return cv2.resize(image, target_size)
def process_image(image):
    """
    run segmentation on a single image and return the rendered prediction

    image: nd array to segment, or None when no image has been supplied
    returns: the first element of the pair returned by img_seg.predict,
    or None when image is None
    """
    # An empty image input arrives as None; return early instead of failing
    # inside resize() on `image.shape`.
    if image is None:
        return None
    image = resize(image)
    prediction, _ = img_seg.predict(image)
    return prediction
def process_video(vid_path=None):
    """
    stream segmentation predictions for a video, one frame at a time

    vid_path: path of the video file to read; None or an empty path yields
    nothing
    yields: for every frame read, the first element of the pair returned by
    img_seg.predict on the resized, BGR->RGB converted frame

    the generator ends when a frame can no longer be read, when the capture
    is not (or no longer) opened, or when the module-level CANCEL_PROCESSING
    flag is found set; the capture is released in every case
    """
    global CANCEL_PROCESSING
    # Start every run with a cleared flag so a cancel request left over from
    # an earlier run does not abort this one immediately.
    CANCEL_PROCESSING = False
    if not vid_path:
        return
    vid_cap = cv2.VideoCapture(vid_path)
    try:
        while vid_cap.isOpened() and not CANCEL_PROCESSING:
            ret, frame = vid_cap.read()
            if not ret:
                # A failed read marks the end of the stream. The capture stays
                # "opened" afterwards, so without this break the loop would
                # spin forever.
                break
            print("Making frame predictions ....")
            frame = resize(frame)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            prediction, _ = img_seg.predict(frame)
            yield prediction
    finally:
        # Also runs when the consumer closes the generator early.
        vid_cap.release()
def update_segmentation_options(options):
    """
    switch the display flags on the shared segmenter on or off

    each flag becomes True when its label is contained in options and
    False otherwise
    """
    label_by_flag = (
        ("is_show_bounding_boxes", 'Show Boundary Box'),
        ("is_show_segmentation", 'Show Segmentation Region'),
        ("is_show_segmentation_boundary", 'Show Segmentation Boundary'),
    )
    for flag_name, label in label_by_flag:
        setattr(img_seg, flag_name, label in options)
def update_confidence_threshold(thres_val):
    """
    store thres_val divided by 100 as the shared segmenter's
    confidence_threshold
    """
    scaled_threshold = thres_val / 100
    img_seg.confidence_threshold = scaled_threshold
def model_selector(model_type):
    """
    replace the shared segmenter with one built for the selected model

    model_type: label describing the wanted model size; any label that is
    not recognised falls back to "yolov8m-seg-custom"

    the display flags and the confidence threshold currently set on the
    shared segmenter are copied onto the new instance so that changing the
    model does not discard them
    """
    # Without this declaration the assignment at the end only binds a local
    # name and the segmenter used by the other callbacks never changes.
    global img_seg

    fallback_model = "yolov8m-seg-custom"
    model_by_label = {
        # NOTE(review): this name uses underscores while the others use
        # hyphens - confirm it is the identifier ImageSegmenter expects.
        "Small - Better performance and less accuracy": "yolov8s_seg_custom",
        "Medium - Balanced performance and accuracy": "yolov8m-seg-custom",
        # NOTE(review): "Large" resolves to the same name as "Medium" -
        # confirm whether a dedicated large model is intended.
        "Large - Slow performance and high accuracy": "yolov8m-seg-custom",
    }
    yolo_model = model_by_label.get(model_type, fallback_model)

    previous = img_seg
    replacement = ImageSegmenter(model_type=yolo_model)
    # Copy over only the settings that were actually set on the old instance.
    carried_settings = (
        "is_show_bounding_boxes",
        "is_show_segmentation",
        "is_show_segmentation_boundary",
        "confidence_threshold",
    )
    for setting in carried_settings:
        if hasattr(previous, setting):
            setattr(replacement, setting, getattr(previous, setting))
    img_seg = replacement
def cancel():
    """
    request cancellation by setting the module-level CANCEL_PROCESSING
    flag to True
    """
    # The global declaration is required: a plain assignment would only
    # create a local variable and leave the module-level flag untouched.
    global CANCEL_PROCESSING
    CANCEL_PROCESSING = True
if __name__ == "__main__":
    # Builds the two-tab Gradio UI (image and video), wires the widgets to the
    # callbacks defined above and launches the app with request queuing.

    # Dropdown configuration shared by both tabs; model_selector() matches on
    # these exact label strings.
    model_choices = [
        "Small - Better performance and less accuracy",
        "Medium - Balanced performance and accuracy",
        "Large - Slow performance and high accuracy",
    ]
    default_model_choice = "Medium - Balanced performance and accuracy"
    model_info = "Select the inference model before running predictions!"
    # Labels offered in the options checkbox group of each tab.
    option_choices = ["Show Boundary Box", "Show Segmentation Region"]
    threshold_info = "Choose the threshold above which objects should be detected"
    base_dir = os.path.dirname(__file__)

    # gradio gui app
    with gr.Blocks() as my_app:
        # title
        gr.Markdown("<h1><center>Hand detection and segmentation</center></h1>")

        # tabs
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    img_input = gr.Image()
                    model_type_img = gr.Dropdown(
                        list(model_choices),
                        label="Model Type", value=default_model_choice,
                        info=model_info)
                    options_checkbox_img = gr.CheckboxGroup(list(option_choices), label="Options")
                    conf_thres_img = gr.Slider(1, 100, value=60, label="Confidence Threshold", info=threshold_info)
                    submit_btn_img = gr.Button(value="Predict")
                with gr.Column(scale=2):
                    with gr.Row():
                        img_output = gr.Image(height=600, label="Segmentation")
            gr.Markdown("## Sample Images")
            gr.Examples(
                examples=[os.path.join(base_dir, "assets/images/img_1.jpg"),
                          os.path.join(base_dir, "assets/images/img_2.jpg")],
                inputs=img_input,
                outputs=img_output,
                fn=process_image,
                cache_examples=True,
            )

        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column(scale=1):
                    vid_input = gr.Video()
                    model_type_vid = gr.Dropdown(
                        list(model_choices),
                        label="Model Type", value=default_model_choice,
                        info=model_info)
                    options_checkbox_vid = gr.CheckboxGroup(list(option_choices), label="Options")
                    conf_thres_vid = gr.Slider(1, 100, value=60, label="Confidence Threshold", info=threshold_info)
                    with gr.Row():
                        cancel_btn = gr.Button(value="Cancel")
                        submit_btn_vid = gr.Button(value="Predict")
                with gr.Column(scale=2):
                    with gr.Row():
                        vid_output = gr.Image(height=600, label="Segmentation")
            gr.Markdown("## Sample Videos")
            gr.Examples(
                examples=[os.path.join(base_dir, "assets/videos/vid_1.mp4"),
                          os.path.join(base_dir, "assets/videos/vid_2.mp4")],
                inputs=vid_input,
            )

        # image tab logic
        submit_btn_img.click(process_image, inputs=img_input, outputs=img_output)
        options_checkbox_img.change(update_segmentation_options, options_checkbox_img, [])
        conf_thres_img.change(update_confidence_threshold, conf_thres_img, [])
        model_type_img.change(model_selector, model_type_img, [])

        # video tab logic
        # Keep the click event so the Cancel button can stop the streaming
        # generator through Gradio's own cancellation mechanism.
        predict_vid_event = submit_btn_vid.click(process_video, inputs=vid_input, outputs=vid_output)
        model_type_vid.change(model_selector, model_type_vid, [])
        cancel_btn.click(cancel, inputs=[], outputs=[], cancels=[predict_vid_event])
        options_checkbox_vid.change(update_segmentation_options, options_checkbox_vid, [])
        conf_thres_vid.change(update_confidence_threshold, conf_thres_vid, [])

    # Cancelling events requires the queue, which is enabled here.
    my_app.queue(concurrency_count=5, max_size=20).launch(debug=True)