# Spaces: Running on Zero
""" | |
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. | |
""" | |
import gradio as gr | |
import spaces | |
import os | |
import sys | |
import torch | |
import torch.nn as nn | |
import torchvision.transforms as T | |
import supervision as sv | |
from PIL import Image | |
import requests | |
import yaml | |
import numpy as np | |
import gc | |
from src.core import YAMLConfig | |
# Pieces shared by every entry of the model registry.
_WEIGHTS_URL_PREFIX = "https://github.com/Peterande/storage/releases/download/dfinev1.0/"
_CONFIG_DIR = "configs/dfine/"
_COCO_CLASSES = "configs/coco.yml"
_OBJ365_CLASSES = "configs/obj365.yml"

# One row per selectable model:
#   (model name, config file relative to _CONFIG_DIR, class-name file)
# The checkpoint URL is always _WEIGHTS_URL_PREFIX + "<model name>.pth".
# Row order is the iteration order of ``model_configs``.
_MODEL_TABLE = (
    ("dfine_n_coco", "dfine_hgnetv2_n_coco.yml", _COCO_CLASSES),
    ("dfine_s_coco", "dfine_hgnetv2_s_coco.yml", _COCO_CLASSES),
    ("dfine_m_coco", "dfine_hgnetv2_m_coco.yml", _COCO_CLASSES),
    ("dfine_l_coco", "dfine_hgnetv2_l_coco.yml", _COCO_CLASSES),
    ("dfine_x_coco", "dfine_hgnetv2_x_coco.yml", _COCO_CLASSES),
    ("dfine_s_obj365", "objects365/dfine_hgnetv2_s_obj365.yml", _OBJ365_CLASSES),
    ("dfine_m_obj365", "objects365/dfine_hgnetv2_m_obj365.yml", _OBJ365_CLASSES),
    ("dfine_l_obj365", "objects365/dfine_hgnetv2_l_obj365.yml", _OBJ365_CLASSES),
    # The "_e25" checkpoint reuses the plain "l_obj365" config file.
    ("dfine_l_obj365_e25", "objects365/dfine_hgnetv2_l_obj365.yml", _OBJ365_CLASSES),
    ("dfine_x_obj365", "objects365/dfine_hgnetv2_x_obj365.yml", _OBJ365_CLASSES),
    ("dfine_s_obj2coco", "objects365/dfine_hgnetv2_s_obj2coco.yml", _COCO_CLASSES),
    ("dfine_m_obj2coco", "objects365/dfine_hgnetv2_m_obj2coco.yml", _COCO_CLASSES),
    # The "_e25" checkpoint reuses the plain "l_obj2coco" config file.
    ("dfine_l_obj2coco_e25", "objects365/dfine_hgnetv2_l_obj2coco.yml", _COCO_CLASSES),
    ("dfine_x_obj2coco", "objects365/dfine_hgnetv2_x_obj2coco.yml", _COCO_CLASSES),
)

# Registry mapping a model name to its config file, class-name file and
# checkpoint download URL.
model_configs = {
    name: {
        "cfgfile": _CONFIG_DIR + cfg_name,
        "classinfofile": class_file,
        "weights": _WEIGHTS_URL_PREFIX + name + ".pth",
    }
    for name, cfg_name, class_file in _MODEL_TABLE
}
def download_weights(model_name):
    """Return the local path of a model's checkpoint, downloading it if needed.

    The checkpoint is stored as ``weights/<model_name>.pth`` next to this
    file. A download is first written to a ``.part`` file and renamed to the
    final path only after it completed, so an interrupted transfer is never
    mistaken for a valid cached checkpoint on a later call.

    Args:
        model_name: Key into ``model_configs``.

    Returns:
        Path of the checkpoint file on disk.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_configs``.
        requests.RequestException: If the download fails, times out, or the
            server answers with an error status.
        OSError: If the weights directory or file cannot be written.
    """
    weights_url = model_configs[model_name]["weights"]
    # Directory path to save weight files
    weights_dir = os.path.join(os.path.dirname(__file__), "weights")
    # Weight file path
    weights_path = os.path.join(weights_dir, model_name + ".pth")

    # Create weights directory if it doesn't exist
    if not os.path.exists(weights_dir):
        os.makedirs(weights_dir, exist_ok=True)
        print(f"Created directory: {weights_dir}")

    # Check if file already exists
    if os.path.exists(weights_path):
        print(f"Weights file already exists at: {weights_path}")
        return weights_path

    # Download file
    print(f"Downloading weights from {weights_url} to {weights_path}...")
    partial_path = weights_path + ".part"
    try:
        # (connect timeout, read timeout) in seconds: without a timeout a
        # stalled connection would block the request forever.
        with requests.get(weights_url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()  # Check for download errors
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        # Publish the file under its final name only once it is complete.
        os.replace(partial_path, weights_path)
    finally:
        # After a successful rename the partial file no longer exists; on any
        # failure this removes the incomplete download.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    print(f"Downloaded weights to: {weights_path}")
    return weights_path
def process_image_for_gradio(model, device, image, model_name, threshold=0.4):
    """Run detection on one image and return the annotated image and a summary.

    Args:
        model: Callable invoked as ``model(images, orig_target_sizes)`` that
            returns ``(labels, boxes, scores)``.
        device: Torch device the input tensors are moved to.
        image: Input picture, either a PIL image or a NumPy array (arrays are
            converted with ``Image.fromarray``).
        model_name: Key into ``model_configs``; selects the class-name file.
        threshold: Only detections with confidence strictly greater than this
            value are kept.

    Returns:
        A tuple ``(annotated_image, summary)``. ``annotated_image`` is a copy
        of the input with boxes and labels drawn on it; ``summary`` has one
        line per kept detection with class name, confidence and box corners.
    """
    im_pil = Image.fromarray(image) if isinstance(image, np.ndarray) else image

    # Load class information; the context manager closes the file handle.
    classinfofile = model_configs[model_name]["classinfofile"]
    with open(classinfofile, "r") as class_file:
        classinfo = yaml.load(class_file, Loader=yaml.FullLoader)["names"]
    # Class files whose path contains "coco" are indexed from 0; every other
    # class file is indexed from 1.
    zero_based = "coco" in classinfofile

    w, h = im_pil.size
    orig_size = torch.tensor([[w, h]]).to(device)
    transforms = T.Compose(
        [
            T.Resize((640, 640)),
            T.ToTensor(),
        ]
    )
    im_data = transforms(im_pil).unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping so no graph is kept in memory.
    with torch.no_grad():
        labels, boxes, scores = model(im_data, orig_size)

    # Visualize results (index 0: the batch holds a single image)
    detections = sv.Detections(
        xyxy=boxes[0].detach().cpu().numpy(),
        confidence=scores[0].detach().cpu().numpy(),
        class_id=labels[0].detach().cpu().numpy().astype(int),
    )
    detections = detections[detections.confidence > threshold]

    text_scale = sv.calculate_optimal_text_scale(resolution_wh=im_pil.size)
    line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=im_pil.size)
    box_annotator = sv.BoxAnnotator(thickness=line_thickness)
    label_annotator = sv.LabelAnnotator(text_scale=text_scale, smart_position=True)

    # Resolve each class id to its name once; both outputs below reuse it.
    class_names = [
        classinfo[class_id if zero_based else class_id - 1]
        for class_id in detections.class_id
    ]
    label_texts = [
        f"{name} {confidence:.2f}"
        for name, confidence in zip(class_names, detections.confidence)
    ]

    result_image = im_pil.copy()
    result_image = box_annotator.annotate(scene=result_image, detections=detections)
    result_image = label_annotator.annotate(
        scene=result_image,
        detections=detections,
        labels=label_texts,
    )

    detection_info = [
        f"{name}: {confidence:.2f}, bbox: [{xyxy[0]:.1f}, {xyxy[1]:.1f}, {xyxy[2]:.1f}, {xyxy[3]:.1f}]"
        for name, confidence, xyxy
        in zip(class_names, detections.confidence, detections.xyxy)
    ]
    return result_image, "\n".join(detection_info)
class ModelWrapper(nn.Module):
    """Chain a deployed detector and its deployed postprocessor in one module."""

    def __init__(self, cfg):
        """Take the deploy-mode model and postprocessor from ``cfg``."""
        super().__init__()
        self.model = cfg.model.deploy()
        self.postprocessor = cfg.postprocessor.deploy()

    def forward(self, images, orig_target_sizes):
        """Run the model on ``images`` and postprocess with ``orig_target_sizes``."""
        raw_predictions = self.model(images)
        return self.postprocessor(raw_predictions, orig_target_sizes)
# Helper that resets the internal state of the YAMLConfig class.
def reset_yaml_config():
    """Clear cached YAMLConfig state and reload the imported ``src.*`` modules.

    If the ``YAMLConfig`` class has an ``_instances`` or ``_configs``
    attribute, it is replaced with an empty dict. Every module currently in
    ``sys.modules`` whose name starts with ``"src."`` is then reloaded.
    Reloading is best effort: a module that fails to reload is reported and
    skipped, while ``KeyboardInterrupt`` and ``SystemExit`` propagate.
    """
    import importlib

    # Drop any information cached on the class itself.
    for cache_attr in ("_instances", "_configs"):
        if hasattr(YAMLConfig, cache_attr):
            setattr(YAMLConfig, cache_attr, {})

    # Reset module-level caches by reloading the project modules. Iterate
    # over a snapshot because reloading can add entries to sys.modules.
    for module_name in list(sys.modules):
        if not module_name.startswith("src."):
            continue
        try:
            importlib.reload(sys.modules[module_name])
        except Exception as exc:
            # Keep going: one module failing to reload must not block the rest.
            print(f"Could not reload module {module_name}: {exc!r}")
def load_model(model_name):
    """Build the model registered under ``model_name`` and load its weights.

    Args:
        model_name: Key into ``model_configs``.

    Returns:
        A tuple ``(model, device)``: the ``ModelWrapper`` in eval mode, moved
        to ``device``, and ``device`` itself, which is ``"cuda"`` when CUDA is
        available and ``"cpu"`` otherwise.
    """
    # Release cached CUDA memory and collect garbage before loading a model.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Reset YAMLConfig internal state.
    reset_yaml_config()

    config_path = model_configs[model_name]["cfgfile"]
    checkpoint_path = download_weights(model_name)

    # Create a completely fresh YAMLConfig instance.
    cfg = YAMLConfig(config_path, resume=checkpoint_path)
    if "HGNetv2" in cfg.yaml_cfg:
        cfg.yaml_cfg["HGNetv2"]["pretrained"] = False

    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Use the EMA weights when the checkpoint has them, else the plain model.
    if "ema" in checkpoint:
        state = checkpoint["ema"]["module"]
    else:
        state = checkpoint["model"]

    # Clean up once more before the model is created.
    torch.cuda.empty_cache()
    gc.collect()

    cfg.model.load_state_dict(state, strict=False)

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model = ModelWrapper(cfg).to(device)
    model.eval()
    return model, device
def process_image(image, model_name, confidence_threshold):
    """Main processing function for Gradio interface.

    Loads the selected model, runs detection on ``image`` and releases the
    model again, so only one model is held in memory per request.

    Args:
        image: Input image from the Gradio image component, or ``None`` when
            nothing was provided.
        model_name: Key into ``model_configs``.
        confidence_threshold: Forwarded to ``process_image_for_gradio`` as its
            ``threshold``.

    Returns:
        The ``(annotated_image, summary)`` tuple produced by
        ``process_image_for_gradio``.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Reject a missing image before the model download and load, instead of
    # failing afterwards with an opaque attribute error.
    if image is None:
        raise gr.Error("Please provide an input image.")

    # Free cached memory on the CUDA devices.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Collect unreachable Python objects.
    gc.collect()

    model = None
    try:
        print(f"Loading model: {model_name}")
        model, device = load_model(model_name)
        # Process the image.
        result = process_image_for_gradio(model, device, image, model_name, confidence_threshold)
    finally:
        # Always drop the model reference, on the failure path too. Collect
        # garbage first so the memory it frees can be returned by empty_cache.
        model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return result
# Input widgets, in the positional order of process_image's parameters:
# image, model name, confidence threshold.
_demo_inputs = [
    gr.Image(type="pil", label="Input Image"),
    gr.Dropdown(
        choices=list(model_configs),
        value="dfine_n_coco",
        label="Model Selection",
    ),
    gr.Slider(
        minimum=0.1,
        maximum=0.9,
        value=0.4,
        step=0.05,
        label="Confidence Threshold",
    ),
]

# Output widgets: the annotated image and the textual detection summary.
_demo_outputs = [
    gr.Image(type="pil", label="Detection Result"),
    gr.Textbox(label="Detected Objects"),
]

# Assemble the Gradio interface around process_image.
demo = gr.Interface(
    fn=process_image,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title="D-FINE Object Detection Demo",
    description="Upload an image to see object detection results using the D-FINE model. You can select different models and adjust the confidence threshold.",
    examples=[["examples/image1.jpg", "dfine_n_coco", 0.4]],
)

if __name__ == "__main__":
    # Start the app only when this file is run as a script.
    demo.launch(share=True)