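"""Gradio demo for SpecDETR: load an MMDetection config/checkpoint and run
image or video detection with DetInferencer."""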
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
from mmdet.apis import DetInferencer

inferencer = None

def load_model(config_path, checkpoint_path):
    global inferencer
    # gr.File may pass a path string or a file wrapper with .name, depending on the Gradio version.
    config_path = getattr(config_path, "name", config_path)
    checkpoint_path = getattr(checkpoint_path, "name", checkpoint_path)
    inferencer = DetInferencer(model=config_path, weights=checkpoint_path)
    return "Model loaded."

def infer_image(image):
    if inferencer is None:
        return "Please load a model first.", None
    # return_vis=True is required for the result dict to contain rendered images.
    result = inferencer(image, return_vis=True)
    vis = result["visualization"]
    if isinstance(vis, list):
        vis = vis[0]
    return "", vis

def infer_video(video):
    if inferencer is None:
        return "Please load a model first.", None
    temp_dir = tempfile.mkdtemp()
    cap = cv2.VideoCapture(video)
    # Fall back to 25 FPS if the container does not report a frame rate.
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out_path = os.path.join(temp_dir, "result.mp4")
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # return_vis=True makes the inferencer return rendered frames under "visualization".
        result = inferencer(frame, return_vis=True)
        vis = result["visualization"]
        if isinstance(vis, list):
            vis = vis[0]
        # The visualization is RGB; VideoWriter expects BGR.
        out.write(cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))
    cap.release()
    out.release()
    return "", out_path

def ui():
    with gr.Blocks() as demo:
        gr.Markdown(
            "# SpecDETR Demo: Image and Video Detection\n"
            "Upload your config (.py) and checkpoint (.pth) to start."
        )
        with gr.Row():
            config = gr.File(label="Config File (.py)")
            checkpoint = gr.File(label="Checkpoint (.pth)")
        load_btn = gr.Button("Load Model")
        load_status = gr.Textbox(label="Status", interactive=False)
        load_btn.click(load_model, inputs=[config, checkpoint], outputs=load_status)
        with gr.Tab("Image"):
            img_input = gr.Image(type="numpy")
            img_output = gr.Image()
            img_btn = gr.Button("Detect on Image")
            img_status = gr.Textbox(label="Status", interactive=False)
            img_btn.click(infer_image, inputs=img_input, outputs=[img_status, img_output])
        with gr.Tab("Video"):
            vid_input = gr.Video()
            vid_output = gr.Video()
            vid_btn = gr.Button("Detect on Video")
            vid_status = gr.Textbox(label="Status", interactive=False)
            vid_btn.click(infer_video, inputs=vid_input, outputs=[vid_status, vid_output])
    return demo

demo = ui()


def main():
    demo.launch()


if __name__ == "__main__":
    main()