#!/usr/bin/env python3
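"""Gradio demo: RTMO multi-person pose estimation with MMPose.

Accepts an image or a video, runs MMPoseInferencer with either a remote RTMO
checkpoint (COCO / BODY7 / CrowdPose) or a user-uploaded .pth file, and returns
the annotated image or video. Intended for a Hugging Face Space with ZeroGPU
(the `spaces.GPU` decorator on `predict`).
"""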
import spaces
import os
import sys
import importlib.util
import re
import gradio as gr
from PIL import Image
import torch
import requests # for downloading remote checkpoints
import shutil
# CUDA info
try:
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
except Exception:
    print("CUDA is not available!")
# ——— Monkey-patch mmdet to remove its mmcv-version assertion ———
spec = importlib.util.find_spec('mmdet')
if spec and spec.origin:
    src = open(spec.origin, encoding='utf-8').read()
    patched = re.sub(r'(?ms)^[ \t]*mmcv_minimum_version.*?^__all__', '__all__', src)
    m = importlib.util.module_from_spec(spec)
    m.__loader__ = spec.loader
    m.__file__ = spec.origin
    m.__path__ = spec.submodule_search_locations
    sys.modules['mmdet'] = m
    exec(compile(patched, spec.origin, 'exec'), m.__dict__)
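# Note on the patch above: mmdet's __init__ asserts that the installed mmcv version
# falls inside a pinned range; presumably the mmcv build in this environment is outside
# that range, so the assertion block is stripped before the module source is executed.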
from mmpose.apis.inferencers import MMPoseInferencer
# Remote checkpoints
REMOTE_CHECKPOINTS = {
    # COCO-trained
    "rtmo-s_8xb32-600e_coco": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-s_8xb32-600e_coco-640x640-8db55a59_20231211.pth",
    "rtmo-m_16xb16-600e_coco": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-m_16xb16-600e_coco-640x640-6f4e0306_20231211.pth",
    "rtmo-l_16xb16-600e_coco": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-l_16xb16-600e_coco-640x640-516a421f_20231211.pth",
    # BODY7-trained
    "rtmo-t_8xb32-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-t_8xb32-600e_body7-416x416-f48f75cb_20231219.pth",
    "rtmo-s_8xb32-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-s_8xb32-600e_body7-640x640-dac2bf74_20231211.pth",
    "rtmo-m_16xb16-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-m_16xb16-600e_body7-640x640-39e78cc4_20231211.pth",
    "rtmo-l_16xb16-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-l_16xb16-600e_body7-640x640-b37118ce_20231211.pth",
    # CrowdPose-trained
    "rtmo-s_8xb32-700e_crowdpose": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-s_8xb32-700e_crowdpose-640x640-79f81c0d_20231211.pth",
    "rtmo-m_16xb16-700e_crowdpose": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-m_16xb16-700e_crowdpose-640x640-0eaf670d_20231211.pth",
    "rtmo-l_16xb16-700e_crowdpose": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-l_16xb16-700e_crowdpose-640x640-1008211f_20231211.pth",
    # Retrainable from HF repo
    "rtmo-s_coco_retrainable": "https://huggingface.co/Luigi/Retrainable-RTMO-s/resolve/main/rtmo-s_coco_retrainable.pth",
}
# RTMO variant aliases for inference, keyed by the output-channel count of the
# checkpoint's stem conv ('backbone.stem.conv.conv.weight'); detect_rtmo_variant()
# uses this width to map an arbitrary checkpoint to a t/s/m/l config alias.
VARIANT_PREFIX = {
    24: "rtmo-t_8xb32-600e_body7-416x416",
    32: "rtmo-s_8xb32-600e_body7-640x640",
    48: "rtmo-m_16xb16-600e_body7-640x640",
    64: "rtmo-l_16xb16-600e_body7-640x640",
}
# ——— Helper: download checkpoint if remote ———
def get_checkpoint(path_or_key: str) -> str:
    """Resolve a REMOTE_CHECKPOINTS key to a cached local file; pass local paths through."""
    if path_or_key in REMOTE_CHECKPOINTS:
        url = REMOTE_CHECKPOINTS[path_or_key]
        local_path = f"/tmp/{path_or_key}.pth"
        if not os.path.exists(local_path):
            r = requests.get(url, stream=True)
            r.raise_for_status()  # fail loudly instead of caching a bad download
            with open(local_path, 'wb') as f:
                for chunk in r.iter_content(1024):
                    f.write(chunk)
        return local_path
    return path_or_key
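# Illustrative usage (comment only): a REMOTE_CHECKPOINTS key is downloaded once and
# cached, e.g. get_checkpoint("rtmo-s_8xb32-600e_coco") -> "/tmp/rtmo-s_8xb32-600e_coco.pth";
# anything else (a local path) is returned unchanged.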
# ——— Detect variant alias from checkpoint ———
def detect_rtmo_variant(checkpoint_path: str) -> str:
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt)
    key = 'backbone.stem.conv.conv.weight'
    if key not in state_dict:
        raise KeyError(f"Cannot find '{key}' in checkpoint.")
    out_ch = state_dict[key].shape[0]
    return VARIANT_PREFIX.get(out_ch, 'rtmo-s_8xb32-600e_body7-640x640')
# ——— Load inferencer ———
def load_inferencer(checkpoint_path=None, device=None):
    kwargs = {'scope': 'mmpose', 'device': device, 'det_cat_ids': [0]}
    if checkpoint_path:
        variant = detect_rtmo_variant(checkpoint_path)
        kwargs['pose2d'] = variant
        kwargs['pose2d_weights'] = checkpoint_path
    else:
        kwargs['pose2d'] = 'rtmo'
    return MMPoseInferencer(**kwargs)
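# Illustrative usage (comment only; nothing here runs at import time):
#   load_inferencer()                    # falls back to the 'rtmo' alias, default device
#   load_inferencer('/tmp/my_rtmo.pth')  # hypothetical path; variant inferred from its weights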
# ——— Prediction function ———
@spaces.GPU()
def predict(image: Image.Image,
            video,            # optional video input
            remote_ckpt: str,
            upload_ckpt,
            bbox_thr: float,
            nms_thr: float):
    # 1) Write the image to disk, or pick up the video file
    if video:
        # Gradio Video can come in as a filepath string or a dict
        if isinstance(video, dict) and 'name' in video:
            inp_path = video['name']
        elif hasattr(video, "name"):
            inp_path = video.name
        else:
            inp_path = video
    else:
        inp_path = "/tmp/upload.jpg"
        image.save(inp_path)
    # 2) Determine by extension whether the input is a video
    ext = os.path.splitext(inp_path)[1].lower()
    is_video = ext in (".mp4", ".mov", ".avi", ".mkv", ".webm")
    # 3) Checkpoint selection: an uploaded .pth takes priority over the remote key
    if upload_ckpt:
        ckpt_path = upload_ckpt.name
        active = os.path.basename(ckpt_path)
    else:
        ckpt_path = get_checkpoint(remote_ckpt)
        active = remote_ckpt
    # 4) Prepare (and clear) the output directory
    vis_dir = "/tmp/vis"
    if os.path.exists(vis_dir):
        shutil.rmtree(vis_dir)
    os.makedirs(vis_dir, exist_ok=True)
    # 5) Run the inferencer (handles both image and video inputs)
    inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
    for _ in inferencer(
        inputs=inp_path,
        bbox_thr=bbox_thr,
        nms_thr=nms_thr,
        pose_based_nms=True,
        show=False,
        vis_out_dir=vis_dir,
    ):
        pass
    # 6) Collect and return results
    out_files = sorted(os.listdir(vis_dir))
    if is_video:
        # return only the annotated video path
        out_vid = next((f for f in out_files if f.lower().endswith((".mp4", ".mov", ".avi"))), None)
        return None, os.path.join(vis_dir, out_vid) if out_vid else None, active
    else:
        # return only the annotated image
        img_f = out_files[0] if out_files else None
        vis_img = Image.open(os.path.join(vis_dir, img_f)) if img_f and not img_f.lower().endswith((".mp4", ".mov", ".avi")) else None
        return vis_img, None, active
# ——— Gradio UI ———
def main():
    with gr.Blocks() as demo:
        gr.Markdown("## RTMO Pose Demo")
        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                img_input = gr.Image(type="pil", label="Upload Image")
                video_input = gr.Video(label="Upload Video")
                remote_dd = gr.Dropdown(
                    label="Select Remote Checkpoint",
                    choices=list(REMOTE_CHECKPOINTS.keys()),
                    value=list(REMOTE_CHECKPOINTS.keys())[0]
                )
                upload_ckpt = gr.File(file_types=['.pth'], label="Or Upload Your Own Checkpoint (optional)")
                bbox_thr = gr.Slider(0.0, 1.0, value=0.1, step=0.01, label="Bounding Box Threshold")
                nms_thr = gr.Slider(0.0, 1.0, value=0.65, step=0.01, label="NMS Threshold")
                run_btn = gr.Button("Run Inference")
            with gr.Column(scale=2):
                output_img = gr.Image(type="pil", label="Annotated Image", elem_id="output_image", interactive=False)
                output_video = gr.Video(label="Annotated Video", interactive=False)
                active_tb = gr.Textbox(label="Active Checkpoint", interactive=False)
        # Examples for quick testing
        gr.Examples(
            examples=[
                ["https://images.pexels.com/photos/1858175/pexels-photo-1858175.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-s_coco_retrainable", None, 0.1, 0.65],
                ["https://images.pexels.com/photos/3779706/pexels-photo-3779706.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-t_8xb32-600e_body7", None, 0.1, 0.65],
                ["https://images.pexels.com/photos/220453/pexels-photo-220453.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-s_8xb32-600e_coco", None, 0.1, 0.65],
                # 4th example: public-domain Fred Ott's Sneeze (1894)
                [None,
                 "https://archive.org/download/fred-otts-sneeze/Fred%20Ott%20Sneeze%201894%20GG%20Restore.mp4",
                 "rtmo-s_coco_retrainable", None, 0.1, 0.65],
            ],
            inputs=[img_input, video_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
            outputs=[output_img, output_video, active_tb],
            fn=predict,
            cache_examples=False,
            label="Examples",
            examples_per_page=4
        )
        run_btn.click(
            predict,
            inputs=[img_input, video_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
            outputs=[output_img, output_video, active_tb]
        )
    demo.launch()

if __name__ == "__main__":
    main()