import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from ultralytics import FastSAM
import supervision as sv
import os
import requests
from tqdm.auto import tqdm  # For a download progress bar

# --- Constants and Model Initialization ---

# CLIP
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"

# FastSAM
# Hugging Face link for the FastSAM-s weights
FASTSAM_WEIGHTS_URL = "https://huggingface.co/spaces/An-619/FastSAM/resolve/6f76f474c656d2cb29599f49c296a8784b02d04b/weights/FastSAM-s.pt"
FASTSAM_WEIGHTS_NAME = "FastSAM-s.pt"

# Default FastSAM parameters
DEFAULT_IMGSZ = 640
DEFAULT_CONFIDENCE = 0.4
DEFAULT_IOU = 0.9
DEFAULT_RETINA_MASKS = False

# --- Helper Functions ---

def download_file(url, filename):
    """Download a file from a URL with a progress bar."""
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Raise an exception for bad status codes
    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 KB
    progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
    with open(filename, 'wb') as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)
    progress_bar.close()
    if total_size != 0 and progress_bar.n != total_size:
        raise ValueError("Error: Download failed.")

# --- Model Loading ---

# Load CLIP model and processor
model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)
processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)

# Download FastSAM weights if they are not present locally
if not os.path.exists(FASTSAM_WEIGHTS_NAME):
    print(f"Downloading FastSAM weights from {FASTSAM_WEIGHTS_URL}...")
    try:
        download_file(FASTSAM_WEIGHTS_URL, FASTSAM_WEIGHTS_NAME)
        print("FastSAM weights downloaded successfully.")
    except Exception as e:
        print(f"Error downloading FastSAM weights: {e}")
        raise  # Re-raise the exception to stop execution

# Load FastSAM on GPU if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fast_sam = FastSAM(FASTSAM_WEIGHTS_NAME)
fast_sam.to(device)
print(f"FastSAM loaded on device: {device}")

# --- Processing Functions ---

def process_image_clip(image, text_input):
    """Zero-shot classification: estimate how likely the image matches text_input."""
    if image is None:
        return "Please upload an image first."
    if not text_input:
        return "Please enter some text to check in the image."

    try:
        # Convert numpy array to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Create a list of candidate labels
        candidate_labels = [text_input, f"not {text_input}"]

        # Process image and text
        inputs = processor(
            images=image,
            text=candidate_labels,
            return_tensors="pt",
            padding=True
        )

        # Get model predictions
        outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)

        # Get confidence for the positive label
        confidence = float(probs[0][0])
        return f"Confidence that the image contains '{text_input}': {confidence:.2%}"
    except Exception as e:
        return f"Error processing image: {str(e)}"

def process_image_fastsam(image, imgsz, conf, iou, retina_masks):
    """Segment the image with FastSAM and return an annotated image or an error message."""
    if image is None:
        return None, "Please upload an image to segment."
    try:
        # Convert PIL image to numpy array if needed
        if isinstance(image, Image.Image):
            image_np = np.array(image)
        else:
            image_np = image

        # Run FastSAM inference
        results = fast_sam(
            image_np,
            device=device,
            retina_masks=retina_masks,
            imgsz=imgsz,
            conf=conf,
            iou=iou
        )

        # Check if results are valid
        if results is None or len(results) == 0 or results[0] is None:
            return None, "FastSAM did not return valid results. Try adjusting parameters or using a different image."

        # Get detections
        detections = sv.Detections.from_ultralytics(results[0])

        # Check if detections are valid
        if detections is None or len(detections) == 0:
            return None, "No objects detected in the image. Try lowering the confidence threshold."

        # Create annotators
        box_annotator = sv.BoxAnnotator()
        mask_annotator = sv.MaskAnnotator()

        # Annotate image with masks first, then bounding boxes
        annotated_image = mask_annotator.annotate(scene=image_np.copy(), detections=detections)
        annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)

        return Image.fromarray(annotated_image), None  # No error message on success
    except RuntimeError as rt_err:
        if "out of memory" in str(rt_err).lower():
            return None, "Error: Out of memory. Try reducing the image size (imgsz) or disabling retina masks."
        return None, f"Runtime error during FastSAM processing: {str(rt_err)}"
    except Exception as e:
        return None, f"Error processing image with FastSAM: {str(e)}"

# --- Gradio Interface ---

with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown("""
    # CLIP and FastSAM Demo
    This demo combines two powerful AI models:
    - **CLIP**: For zero-shot image classification
    - **FastSAM**: For automatic image segmentation

    Upload an image and try either of the tabs below!
""") with gr.Tab("CLIP Zero-Shot Classification"): with gr.Row(): image_input = gr.Image(label="Input Image") text_input = gr.Textbox( label="What do you want to check in the image?", placeholder="e.g., 'a dog', 'sunset', 'people playing'", info="Enter any concept you want to check in the image" ) output_text = gr.Textbox(label="Result") classify_btn = gr.Button("Classify") classify_btn.click(fn=process_image_clip, inputs=[image_input, text_input], outputs=output_text) gr.Examples( examples=[ ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen/kitchen.png", "kitchen"], ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/calculator/calculator.jpg", "calculator"], ], inputs=[image_input, text_input], ) with gr.Tab("FastSAM Segmentation"): with gr.Row(): image_input_sam = gr.Image(label="Input Image") with gr.Column(): imgsz_slider = gr.Slider(minimum=320, maximum=1920, step=32, value=DEFAULT_IMGSZ, label="Image Size (imgsz)") conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=DEFAULT_CONFIDENCE, label="Confidence Threshold") iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=DEFAULT_IOU, label="IoU Threshold") retina_checkbox = gr.Checkbox(label="Retina Masks", value=DEFAULT_RETINA_MASKS) with gr.Row(): image_output = gr.Image(label="Segmentation Result") error_output = gr.Textbox(label="Error Message", type="text") # Added for displaying errors segment_btn = gr.Button("Segment") segment_btn.click( fn=process_image_fastsam, inputs=[image_input_sam, imgsz_slider, conf_slider, iou_slider, retina_checkbox], outputs=[image_output, error_output] # Output to both image and error textboxes ) gr.Examples( examples=[ ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen/kitchen.png"], ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/calculator/calculator.jpg"], ], inputs=[image_input_sam], ) # ... (Your final Markdown remains the same) ... gr.Markdown(""" ### How to use: 1. **CLIP Classification**: Upload an image and enter text to check if that concept exists in the image 2. **FastSAM Segmentation**: Upload an image to get automatic segmentation with bounding boxes and masks ### Note: - The models run on CPU by default, so processing might take a few seconds. If you have a GPU, it will be used automatically. - For best results, use clear images with good lighting. - You can adjust FastSAM parameters (Image Size, Confidence, IoU, Retina Masks) in the Segmentation tab. """) demo.launch(share=True)