# Spaces: Running on Zero
#!/usr/bin/env python3 | |
import spaces | |
import os | |
import sys | |
import importlib.util | |
import re | |
import gradio as gr | |
from PIL import Image | |
import torch | |
import requests # for downloading remote checkpoints | |
import shutil | |
# CUDA info: best-effort diagnostics printed once at import time.
try:
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
except Exception:
    # Catch Exception rather than using a bare ``except`` so that
    # KeyboardInterrupt / SystemExit still propagate during startup.
    print('CUDA is not available !')
# ─── Monkey-patch mmdet to remove its mmcv-version assertion ───
# Locate the installed mmdet package without importing it, so its top-level
# code (including the mmcv version check) does not run unmodified.
spec = importlib.util.find_spec('mmdet')
if spec and spec.origin:
    # Read the package's __init__ source; the context manager closes the file
    # handle instead of leaving it to the garbage collector.
    with open(spec.origin, encoding='utf-8') as src_file:
        src = src_file.read()
    # Cut everything from the first line starting with ``mmcv_minimum_version``
    # up to (but keeping) the ``__all__`` line.
    patched = re.sub(r'(?ms)^[ \t]*mmcv_minimum_version.*?^__all__', '__all__', src)
    m = importlib.util.module_from_spec(spec)
    m.__loader__ = spec.loader
    m.__file__ = spec.origin
    m.__path__ = spec.submodule_search_locations
    # Register the module before executing the patched source so that any
    # ``mmdet`` import triggered during exec resolves to this module object.
    sys.modules['mmdet'] = m
    exec(compile(patched, spec.origin, 'exec'), m.__dict__)
from mmpose.apis.inferencers import MMPoseInferencer | |
# Remote checkpoints: dropdown key -> download URL of the matching .pth file.
# Every file name starts with its key; the first entry is the UI default.
REMOTE_CHECKPOINTS = {
    # COCO-trained
    "rtmo-s_8xb32-600e_coco": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-s_8xb32-600e_coco-640x640-8db55a59_20231211.pth",
    "rtmo-m_16xb16-600e_coco": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-m_16xb16-600e_coco-640x640-6f4e0306_20231211.pth",
    "rtmo-l_16xb16-600e_coco": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-l_16xb16-600e_coco-640x640-516a421f_20231211.pth",
    # BODY7-trained
    "rtmo-t_8xb32-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-t_8xb32-600e_body7-416x416-f48f75cb_20231219.pth",
    "rtmo-s_8xb32-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-s_8xb32-600e_body7-640x640-dac2bf74_20231211.pth",
    "rtmo-m_16xb16-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-m_16xb16-600e_body7-640x640-39e78cc4_20231211.pth",
    "rtmo-l_16xb16-600e_body7": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-l_16xb16-600e_body7-640x640-b37118ce_20231211.pth",
    # CrowdPose-trained
    "rtmo-s_8xb32-700e_crowdpose": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-s_8xb32-700e_crowdpose-640x640-79f81c0d_20231211.pth",
    # Fixed: the file name previously read "rrtmo-m_..." (stray leading "r").
    "rtmo-m_16xb16-700e_crowdpose": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-m_16xb16-700e_crowdpose-640x640-0eaf670d_20231211.pth",
    "rtmo-l_16xb16-700e_crowdpose": "https://download.openmmlab.com/mmpose/v1/projects/rtmo/rtmo-l_16xb16-700e_crowdpose-640x640-1008211f_20231211.pth",
    # Retrainable from HF repo
    "rtmo-s_coco_retrainable": "https://huggingface.co/Luigi/Retrainable-RTMO-s/resolve/main/rtmo-s_coco_retrainable.pth",
}
# Variants for inference (prefixes).
# Keys are compared against the first dimension of the checkpoint's
# 'backbone.stem.conv.conv.weight' tensor; values are the model names
# handed to the inferencer as ``pose2d``.
VARIANT_PREFIX: dict[int, str] = {
    24: "rtmo-t_8xb32-600e_body7-416x416",    # tiny
    32: "rtmo-s_8xb32-600e_body7-640x640",    # small
    48: "rtmo-m_16xb16-600e_body7-640x640",   # medium
    64: "rtmo-l_16xb16-600e_body7-640x640",   # large
}
# ─── Helper: download checkpoint if remote ───
def get_checkpoint(path_or_key: str) -> str:
    """Return a local file path for a checkpoint key or path.

    Args:
        path_or_key: Either a key of ``REMOTE_CHECKPOINTS`` or any other
            string, which is treated as an already-usable path.

    Returns:
        ``/tmp/<key>.pth`` for a known key (downloading it on first use),
        otherwise ``path_or_key`` unchanged.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    url = REMOTE_CHECKPOINTS.get(path_or_key)
    if url is None:
        return path_or_key

    local_path = f"/tmp/{path_or_key}.pth"
    if os.path.exists(local_path):
        return local_path

    # Download to a side file and rename only on success, so an interrupted
    # or failed transfer never leaves a truncated file at ``local_path`` that
    # later calls would mistake for a complete checkpoint.
    partial_path = f"{local_path}.part"
    try:
        with requests.get(url, stream=True, timeout=60) as response:
            # Without this, a 404/500 error page would be saved as the .pth.
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        os.replace(partial_path, local_path)
    finally:
        # After a successful rename the side file no longer exists.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return local_path
# ─── Detect variant alias from checkpoint ───
def detect_rtmo_variant(checkpoint_path: str) -> str:
    """Pick the RTMO model name that matches a checkpoint's stem width.

    Args:
        checkpoint_path: Path to a checkpoint file readable by ``torch.load``.

    Returns:
        The ``VARIANT_PREFIX`` entry for the first dimension of the
        ``backbone.stem.conv.conv.weight`` tensor, or
        ``'rtmo-s_8xb32-600e_body7-640x640'`` when that size is not listed.

    Raises:
        KeyError: If the stem weight is missing from the checkpoint.
    """
    stem_key = 'backbone.stem.conv.conv.weight'
    # NOTE(review): torch.load unpickles the file, and the path may point at a
    # user-uploaded checkpoint -- confirm this is acceptable for the deployment.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Weights may be nested under 'state_dict' or stored at the top level.
    weights = checkpoint.get('state_dict', checkpoint)
    if stem_key in weights:
        stem_channels = weights[stem_key].shape[0]
        return VARIANT_PREFIX.get(stem_channels, 'rtmo-s_8xb32-600e_body7-640x640')
    raise KeyError(f"Cannot find '{stem_key}' in checkpoint.")
# ─── Load inferencer ───
def load_inferencer(checkpoint_path=None, device=None):
    """Construct an ``MMPoseInferencer`` for RTMO.

    Args:
        checkpoint_path: Optional checkpoint file. When given, the model name
            is derived from it via ``detect_rtmo_variant`` and the file is
            passed as ``pose2d_weights``; otherwise the ``'rtmo'`` alias is
            used with no explicit weights.
        device: Forwarded unchanged to ``MMPoseInferencer``.

    Returns:
        The constructed ``MMPoseInferencer`` instance.
    """
    if checkpoint_path:
        model_args = {
            'pose2d': detect_rtmo_variant(checkpoint_path),
            'pose2d_weights': checkpoint_path,
        }
    else:
        model_args = {'pose2d': 'rtmo'}
    return MMPoseInferencer(
        scope='mmpose',
        device=device,
        det_cat_ids=[0],
        **model_args,
    )
# ──── Prediction function ────
def predict(image: Image.Image,
            video,  # new video input
            remote_ckpt: str,
            upload_ckpt,
            bbox_thr: float,
            nms_thr: float):
    """Run RTMO pose estimation on an uploaded image or video.

    Args:
        image: PIL image from the image input; used only when no video is given.
        video: Video input as a path string, a dict with a ``'name'`` entry,
            or an object with a ``name`` attribute. Takes priority over
            ``image`` when truthy.
        remote_ckpt: Key (or path) resolved through ``get_checkpoint`` when no
            checkpoint file is uploaded.
        upload_ckpt: Optional uploaded checkpoint, as a path string or an
            object with a ``name`` attribute.
        bbox_thr: Forwarded to the inferencer as ``bbox_thr``.
        nms_thr: Forwarded to the inferencer as ``nms_thr``.

    Returns:
        ``(annotated_image, annotated_video_path, active_checkpoint)``. For a
        video input the image slot is ``None``; for an image input the video
        slot is ``None``. Either result slot is ``None`` when the inferencer
        wrote no matching file.

    Raises:
        gr.Error: If neither an image nor a video was provided.
    """
    # One extension list for both classifying the input and finding the
    # rendered output, so e.g. a .webm input is not lost on the way out.
    video_exts = (".mp4", ".mov", ".avi", ".mkv", ".webm")

    # 1) Write image or pick up video file
    if video:
        # Gradio Video can come in as a filepath string or dict
        if isinstance(video, dict) and 'name' in video:
            inp_path = video['name']
        elif hasattr(video, "name"):
            inp_path = video.name
        else:
            inp_path = video
    elif image is not None:
        inp_path = "/tmp/upload.jpg"
        # JPEG cannot store alpha or palette modes (e.g. RGBA/P PNG uploads).
        if image.mode != "RGB":
            image = image.convert("RGB")
        image.save(inp_path)
    else:
        raise gr.Error("Please upload an image or a video first.")

    # 2) Determine by extension if this is video
    ext = os.path.splitext(inp_path)[1].lower()
    is_video = ext in video_exts

    # Checkpoint selection: an uploaded file wins over the dropdown choice.
    if upload_ckpt:
        # Accept both a file-like object exposing ``.name`` and a plain path.
        ckpt_path = getattr(upload_ckpt, "name", upload_ckpt)
        active = os.path.basename(ckpt_path)
    else:
        ckpt_path = get_checkpoint(remote_ckpt)
        active = remote_ckpt

    # Prepare (and clear) the output dir so only this run's files are listed.
    vis_dir = "/tmp/vis"
    if os.path.exists(vis_dir):
        shutil.rmtree(vis_dir)
    os.makedirs(vis_dir, exist_ok=True)

    # Run inferencer (handles both image & video). The call is iterated to
    # completion and its results discarded; only files in vis_dir are used.
    inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
    for _ in inferencer(
        inputs=inp_path,
        bbox_thr=bbox_thr,
        nms_thr=nms_thr,
        pose_based_nms=True,
        show=False,
        vis_out_dir=vis_dir,
    ):
        pass

    # Collect and return results
    out_files = sorted(os.listdir(vis_dir))
    if is_video:
        # Return only the annotated video path
        out_vid = next((f for f in out_files if f.lower().endswith(video_exts)), None)
        video_path = os.path.join(vis_dir, out_vid) if out_vid else None
        return None, video_path, active

    # Return only the annotated image
    vis_img = None
    if out_files and not out_files[0].lower().endswith(video_exts):
        # Copy the pixels into memory and close the file: vis_dir is wiped at
        # the start of the next call, so the result must not depend on it.
        with Image.open(os.path.join(vis_dir, out_files[0])) as rendered:
            vis_img = rendered.copy()
    return vis_img, None, active
# ──── Gradio UI ────
def main():
    """Build the Gradio Blocks interface for the RTMO demo and launch it."""
    checkpoint_names = list(REMOTE_CHECKPOINTS.keys())

    with gr.Blocks() as demo:
        gr.Markdown("## RTMO Pose Demo")
        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column(scale=1, min_width=300):
                image_in = gr.Image(type="pil", label="Upload Image")
                video_in = gr.Video(label="Upload Video")
                ckpt_dropdown = gr.Dropdown(
                    label="Select Remote Checkpoint",
                    choices=checkpoint_names,
                    value=checkpoint_names[0],
                )
                ckpt_file = gr.File(file_types=['.pth'], label="Or Upload Your Own Checkpoint (optional)")
                bbox_slider = gr.Slider(0.0, 1.0, value=0.1, step=0.01, label="Bounding Box Threshold")
                nms_slider = gr.Slider(0.0, 1.0, value=0.65, step=0.01, label="NMS Threshold")
                run_button = gr.Button("Run Inference")
            # Right column: results.
            with gr.Column(scale=2):
                image_out = gr.Image(type="pil", label="Annotated Image", elem_id="output_image", interactive=False)
                video_out = gr.Video(label="Annotated Video", interactive=False)
                active_box = gr.Textbox(label="Active Checkpoint", interactive=False)

        # The examples and the button feed the same components to predict().
        predict_inputs = [image_in, video_in, ckpt_dropdown, ckpt_file, bbox_slider, nms_slider]
        predict_outputs = [image_out, video_out, active_box]

        # Examples for quick testing; each row follows the order of
        # predict_inputs: image, video, checkpoint key, uploaded file,
        # bbox threshold, NMS threshold.
        example_rows = [
            ["https://images.pexels.com/photos/1858175/pexels-photo-1858175.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-s_coco_retrainable", None, 0.1, 0.65],
            ["https://images.pexels.com/photos/3779706/pexels-photo-3779706.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-t_8xb32-600e_body7", None, 0.1, 0.65],
            ["https://images.pexels.com/photos/220453/pexels-photo-220453.jpeg?auto=compress&cs=tinysrgb&h=614&w=614", None, "rtmo-s_8xb32-600e_coco", None, 0.1, 0.65],
            # 4th example: a video; the URL points at "Fred Ott's Sneeze" (1894).
            [None,
             "https://archive.org/download/fred-otts-sneeze/Fred%20Ott%20Sneeze%201894%20GG%20Restore.mp4",
             "rtmo-s_coco_retrainable", None, 0.1, 0.65],
        ]
        gr.Examples(
            examples=example_rows,
            inputs=predict_inputs,
            outputs=predict_outputs,
            fn=predict,
            cache_examples=False,
            label="Examples",
            examples_per_page=4,
        )

        run_button.click(
            predict,
            inputs=predict_inputs,
            outputs=predict_outputs,
        )

    demo.launch()


if __name__ == "__main__":
    main()