""" Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. """ import gradio as gr import spaces import os import sys import torch import torch.nn as nn import torchvision.transforms as T import supervision as sv from PIL import Image import requests import yaml import numpy as np import gc from src.core import YAMLConfig model_configs = { "dfine_n_coco": {"cfgfile": "configs/dfine/dfine_hgnetv2_n_coco.yml", "classinfofile": "configs/coco.yml", "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_n_coco.pth"}, "dfine_s_coco": {"cfgfile": "configs/dfine/dfine_hgnetv2_s_coco.yml", "classinfofile": "configs/coco.yml", "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_s_coco.pth"}, "dfine_m_coco": {"cfgfile": "configs/dfine/dfine_hgnetv2_m_coco.yml", "classinfofile": "configs/coco.yml", "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_m_coco.pth"}, "dfine_l_coco": {"cfgfile": "configs/dfine/dfine_hgnetv2_l_coco.yml", "classinfofile": "configs/coco.yml", "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_coco.pth"}, "dfine_x_coco": {"cfgfile": "configs/dfine/dfine_hgnetv2_x_coco.yml", "classinfofile": "configs/coco.yml", "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_x_coco.pth"}, "dfine_s_obj365": {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_s_obj365.yml", "classinfofile": "configs/obj365.yml", "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_s_obj365.pth"}, "dfine_m_obj365": {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_m_obj365.yml", "classinfofile": "configs/obj365.yml", "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_m_obj365.pth"}, "dfine_l_obj365": {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_l_obj365.yml", "classinfofile": "configs/obj365.yml", "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_obj365.pth"}, "dfine_l_obj365_e25": {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_l_obj365.yml", "classinfofile": "configs/obj365.yml", "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_obj365_e25.pth"}, "dfine_x_obj365": {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_x_obj365.yml", "classinfofile": "configs/obj365.yml", "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_x_obj365.pth"}, "dfine_s_obj2coco": {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_s_obj2coco.yml", "classinfofile": "configs/coco.yml", "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_s_obj2coco.pth"}, "dfine_m_obj2coco": {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_m_obj2coco.yml", "classinfofile": "configs/coco.yml", "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_m_obj2coco.pth"}, "dfine_l_obj2coco_e25": {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_l_obj2coco.yml", "classinfofile": "configs/coco.yml", "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_obj2coco_e25.pth"}, "dfine_x_obj2coco": {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_x_obj2coco.yml", "classinfofile": "configs/coco.yml", "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_x_obj2coco.pth"}, } def download_weights(model_name): """Download model weights if not already present""" weights_url = model_configs[model_name]["weights"] # 


@torch.no_grad()
def process_image_for_gradio(model, device, image, model_name, threshold=0.4):
    """Process an image and return the annotated result for the Gradio interface."""
    # Convert a NumPy array to a PIL image if needed
    if isinstance(image, np.ndarray):
        im_pil = Image.fromarray(image)
    else:
        im_pil = image

    # Load class information
    classinfofile = model_configs[model_name]["classinfofile"]
    with open(classinfofile, "r") as f:
        classinfo = yaml.load(f, Loader=yaml.FullLoader)["names"]
    # COCO class files use 0-based ids; Objects365 files use 1-based ids
    indexing_method = "0-based" if "coco" in classinfofile else "1-based"

    w, h = im_pil.size
    orig_size = torch.tensor([[w, h]]).to(device)

    transforms = T.Compose(
        [
            T.Resize((640, 640)),
            T.ToTensor(),
        ]
    )
    im_data = transforms(im_pil).unsqueeze(0).to(device)

    output = model(im_data, orig_size)
    labels, boxes, scores = output

    # Visualize results
    detections = sv.Detections(
        xyxy=boxes[0].detach().cpu().numpy(),
        confidence=scores[0].detach().cpu().numpy(),
        class_id=labels[0].detach().cpu().numpy().astype(int),
    )
    detections = detections[detections.confidence > threshold]

    text_scale = sv.calculate_optimal_text_scale(resolution_wh=im_pil.size)
    line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=im_pil.size)
    box_annotator = sv.BoxAnnotator(thickness=line_thickness)
    label_annotator = sv.LabelAnnotator(text_scale=text_scale, smart_position=True)

    label_texts = [
        f"{classinfo[class_id if indexing_method == '0-based' else class_id - 1]} {confidence:.2f}"
        for class_id, confidence in zip(detections.class_id, detections.confidence)
    ]

    result_image = im_pil.copy()
    result_image = box_annotator.annotate(scene=result_image, detections=detections)
    result_image = label_annotator.annotate(
        scene=result_image, detections=detections, labels=label_texts
    )

    detection_info = [
        f"{classinfo[class_id if indexing_method == '0-based' else class_id - 1]}: "
        f"{confidence:.2f}, bbox: [{xyxy[0]:.1f}, {xyxy[1]:.1f}, {xyxy[2]:.1f}, {xyxy[3]:.1f}]"
        for class_id, confidence, xyxy in zip(
            detections.class_id, detections.confidence, detections.xyxy
        )
    ]

    return result_image, "\n".join(detection_info)


class ModelWrapper(nn.Module):
    """Bundle the deployed model and postprocessor into a single module."""

    def __init__(self, cfg):
        super().__init__()
        self.model = cfg.model.deploy()
        self.postprocessor = cfg.postprocessor.deploy()

    def forward(self, images, orig_target_sizes):
        outputs = self.model(images)
        outputs = self.postprocessor(outputs, orig_target_sizes)
        return outputs


# Helper that resets YAMLConfig's internal state between model loads
def reset_yaml_config():
    """Reset the internal state of the YAMLConfig class."""
    # Drop any information cached on the class itself
    if hasattr(YAMLConfig, "_instances"):
        YAMLConfig._instances = {}
    if hasattr(YAMLConfig, "_configs"):
        YAMLConfig._configs = {}

    # Best-effort reload of all project modules to clear other module-level caches
    import importlib

    for module_name in list(sys.modules.keys()):
        if module_name.startswith("src."):
            try:
                importlib.reload(sys.modules[module_name])
            except Exception:
                pass
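
# A minimal end-to-end sketch without the Gradio UI, assuming an image exists
# at examples/image1.jpg (the same file referenced in `examples` below):
#
#   model, device = load_model("dfine_n_coco")
#   image = Image.open("examples/image1.jpg").convert("RGB")
#   annotated, info = process_image_for_gradio(
#       model, device, image, "dfine_n_coco", threshold=0.4
#   )
#   annotated.save("detection_result.jpg")
#   print(info)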


def load_model(model_name):
    """Build the requested model with downloaded weights and return (model, device)."""
    # Clear the CUDA cache and garbage-collect before loading a new model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Reset YAMLConfig's internal state
    reset_yaml_config()

    cfgfile = model_configs[model_name]["cfgfile"]
    weights_path = download_weights(model_name)

    # Create a completely fresh YAMLConfig instance
    cfg = YAMLConfig(cfgfile, resume=weights_path)

    if "HGNetv2" in cfg.yaml_cfg:
        cfg.yaml_cfg["HGNetv2"]["pretrained"] = False

    checkpoint = torch.load(weights_path, map_location="cpu")
    state = checkpoint["ema"]["module"] if "ema" in checkpoint else checkpoint["model"]

    # Free memory once more before building the model
    torch.cuda.empty_cache()
    gc.collect()

    cfg.model.load_state_dict(state, strict=False)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ModelWrapper(cfg).to(device)
    model.eval()

    return model, device


@spaces.GPU
def process_image(image, model_name, confidence_threshold):
    """Main processing function for the Gradio interface."""
    # Reclaim as much CUDA device memory as possible
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Garbage-collect unreachable Python objects
    gc.collect()

    try:
        print(f"Loading model: {model_name}")
        model, device = load_model(model_name)

        # Run detection on the input image
        result = process_image_for_gradio(
            model, device, image, model_name, confidence_threshold
        )

        # Explicitly drop the model object and its associated data
        del model
    finally:
        # Always clean up memory, even if inference fails
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    return result


# Create the Gradio interface
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(
            choices=list(model_configs.keys()),
            value="dfine_n_coco",
            label="Model Selection",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=0.9,
            value=0.4,
            step=0.05,
            label="Confidence Threshold",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Detection Result"),
        gr.Textbox(label="Detected Objects"),
    ],
    title="D-FINE Object Detection Demo",
    description=(
        "Upload an image to see object detection results using the D-FINE model. "
        "You can select different models and adjust the confidence threshold."
    ),
    examples=[
        ["examples/image1.jpg", "dfine_n_coco", 0.4],
    ],
)

if __name__ == "__main__":
    # Launch the Gradio app
    demo.launch(share=True)
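
# For local development, a fixed host/port can be used instead of a public
# share link (a sketch; both are standard `gradio` launch parameters):
#
#   demo.launch(server_name="0.0.0.0", server_port=7860)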