Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,054 Bytes
420fa3e 78d66d3 fbd12ae 420fa3e 4ce6c6b c73b59b 7788a89 c73b59b 78d66d3 4ce6c6b c73b59b 5017de6 c73b59b 420fa3e 5017de6 db79c47 78d66d3 c73b59b a4a6a96 c73b59b c30e671 4ce6c6b 420fa3e c73b59b 5017de6 c73b59b 78d66d3 c73b59b 78d66d3 c73b59b 78d66d3 c73b59b a4a6a96 c73b59b 78d66d3 c73b59b 78d66d3 a4a6a96 5017de6 a4a6a96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
from transformers import pipeline, SamModel, SamProcessor
import torch
import os
import numpy as np
import spaces
import gradio as gr
import shutil
from PIL import Image
def find_cuda():
    """Locate the CUDA toolkit installation directory.

    Lookup order:
      1. The ``CUDA_HOME`` environment variable, if it names an existing path.
      2. The ``CUDA_PATH`` environment variable, if it names an existing path.
      3. The directory two levels above the ``nvcc`` executable found on
         ``PATH`` (i.e. ``<root>`` for ``<root>/bin/nvcc``).

    Returns:
        The installation path as a string, or ``None`` when nothing was found.
    """
    # Try each variable on its own so that a stale CUDA_HOME (set, but pointing
    # at a path that no longer exists) does not hide a valid CUDA_PATH.
    for variable in ("CUDA_HOME", "CUDA_PATH"):
        candidate = os.environ.get(variable)
        if candidate and os.path.exists(candidate):
            return candidate

    # Fall back to the compiler on PATH and strip the trailing 'bin/nvcc'.
    nvcc_path = shutil.which("nvcc")
    if nvcc_path:
        return os.path.dirname(os.path.dirname(nvcc_path))

    return None
# Report at start-up whether a CUDA toolkit could be located on this machine.
cuda_path = find_cuda()
print(
    f"CUDA installation found at: {cuda_path}"
    if cuda_path
    else "CUDA installation not found"
)
# Default to the CPU and switch to the GPU only when PyTorch reports one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Zero-shot object detector: turns the text labels into bounding boxes.
checkpoint = "google/owlv2-base-patch16-ensemble"
detector = pipeline(
    task="zero-shot-object-detection",
    model=checkpoint,
    device=device,
)

# Segmentation model and its processor are loaded from the same checkpoint.
sam_checkpoint = "jadechoghari/robustsam-vit-huge"
sam_model = SamModel.from_pretrained(sam_checkpoint).to(device)
sam_processor = SamProcessor.from_pretrained(sam_checkpoint)
def apply_mask(image, mask, color):
    """Paint ``color`` over the pixels of ``image`` selected by ``mask``.

    The first three channels of ``image`` are overwritten in place wherever
    ``mask`` is truthy; every other pixel keeps its current value.

    Returns:
        The same ``image`` array that was passed in.
    """
    channel_values = (color[0], color[1], color[2])
    for channel, value in enumerate(channel_values):
        current_plane = image[:, :, channel]
        image[:, :, channel] = np.where(mask, value, current_plane)
    return image
@spaces.GPU
def query(image, texts, threshold):
    """Detect the objects named in ``texts`` and overlay a colored mask on each.

    The zero-shot ``detector`` proposes boxes for the candidate labels. Every
    detection that passes the score cutoff is sent to ``sam_model`` as a box
    prompt, and the resulting mask is painted onto a copy of the input.

    Args:
        image: Input picture; the Gradio input component delivers a PIL image.
        texts: Comma-separated candidate labels, e.g. ``"bus, window"``.
        threshold: Score threshold forwarded to the detector.

    Returns:
        A new PIL image with one solid-color overlay per segmented detection.

    Raises:
        gr.Error: If no image or no usable label was supplied.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")

    # Drop surrounding blanks and empty entries so that "bus, window," yields
    # ["bus", "window"] instead of ["bus", " window", ""].
    labels = [part.strip() for part in (texts or "").split(",") if part.strip()]
    if not labels:
        raise gr.Error("Please provide at least one candidate label.")

    predictions = detector(
        image,
        candidate_labels=labels,
        threshold=threshold,
    )

    # The segmenter must always see the untouched pixels, so overlays are
    # painted on a separate canvas rather than on the array it is prompted with.
    source = np.array(image)
    canvas = source.copy()

    colors = (
        (255, 0, 0),    # Red
        (0, 255, 0),    # Green
        (0, 0, 255),    # Blue
        (255, 255, 0),  # Yellow
        (255, 165, 0),  # Orange
        (255, 0, 255),  # Magenta
    )

    # NOTE(review): this cutoff is applied on top of the user-supplied
    # ``threshold``, so slider values below 0.5 do not add any masks. Kept as-is
    # because it also bounds the number of segmentation passes; confirm intent.
    min_score = 0.5

    for index, pred in enumerate(predictions):
        if pred["score"] <= min_score:
            continue

        bounds = pred["box"]
        box = [round(bounds[key], 2) for key in ("xmin", "ymin", "xmax", "ymax")]

        inputs = sam_processor(
            source,
            input_boxes=[[[box]]],
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = sam_model(**inputs)

        # Keep the first entry along each of the three leading dimensions of
        # the post-processed output (presumably image, box prompt, and mask
        # candidate; verify against the processor documentation).
        mask = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0][0][0].numpy()

        # The color is chosen by the detection's position in the full
        # prediction list, cycling when there are more detections than colors.
        canvas = apply_mask(canvas, mask > 0.5, colors[index % len(colors)])

    return Image.fromarray(canvas)
# Markdown heading shown at the top of the demo page.
title = "\n# RobustSAM\n"

# Markdown introduction shown below the heading, one source line per entry.
# The empty first and last entries reproduce the leading and trailing newline.
description = "\n".join(
    (
        "",
        "**Welcome to RobustSAM by Snap Research.**",
        "This Space uses **RobustSAM**, a robust version of the Segment Anything Model (SAM) with improved performance on low-quality images while maintaining zero-shot segmentation capabilities.",
        "Thanks to its integration with **OWLv2**, RobustSAM becomes text-promptable, allowing for flexible and accurate segmentation, even with degraded image quality.",
        "Try the example or input an image with comma-separated candidate labels to see the enhanced segmentation results.",
        "For better results, please check the [GitHub repository](https://github.com/robustsam/RobustSAM).",
        "",
    )
)
# Build the page: heading, introduction, then the image/labels/threshold form
# wired to ``query``; finally start the Gradio server.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    # Components are created in the same order as the form fields they feed.
    image_input = gr.Image(type="pil", label="Image Input")
    labels_input = gr.Textbox(label="Candidate Labels")
    threshold_input = gr.Slider(0, 1, value=0.05, label="Confidence Threshold")
    segmented_output = gr.Image(type="pil", label="Segmented Image")

    # Each row supplies one value per input: image path, labels, threshold.
    example_rows = [
        ["./blur.jpg", "insect", 0.1],
        ["./lowlight.jpg", "bus, window", 0.1],
    ]

    gr.Interface(
        query,
        inputs=[image_input, labels_input, threshold_input],
        outputs=segmented_output,
        examples=example_rows,
        cache_examples=True,
    )

demo.launch()
|