# Spaces: Sleeping
import io
import os

import numpy as np
import requests
from fastapi import FastAPI, HTTPException, UploadFile
from huggingface_hub import hf_hub_download
from PIL import Image
from ultralytics import YOLO, YOLOE
# --- Environment setup (run once before starting the app) ---
# pip install -q "git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/CLIP"
# pip install -q "git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/ml-mobileclip"
# pip install -q "git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/lvis-api"
# pip install -q "git+https://github.com/THU-MIG/yoloe.git"
# wget -q https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt
def init_model(*, model_id="yoloe-11s", is_pf=False, models_dir="models"):
    """Download a YOLOE segmentation checkpoint and return it in eval mode.

    The checkpoint comes from the ``jameslahm/yoloe`` Hugging Face repo and is
    stored under *models_dir*.

    Args:
        model_id: Checkpoint family name, e.g. ``"yoloe-11s"``.
        is_pf: Load the prompt-free ``-seg-pf`` checkpoint instead of the
            text/visual-prompt ``-seg`` one. Defaults to False because this
            service sets text-prompted classes via ``set_classes``. The YOLOE
            docs describe that workflow only for the non-prompt-free checkpoints.
        models_dir: Local directory the checkpoint is downloaded into.

    Returns:
        The loaded ``YOLOE`` model, switched to eval mode.
    """
    os.makedirs(models_dir, exist_ok=True)
    suffix = "-seg-pf.pt" if is_pf else "-seg.pt"
    filename = f"{model_id}{suffix}"
    # Pass local_dir so the file really lands in models_dir. hf_hub_download
    # returns an absolute path, so joining a directory onto its return value
    # would silently discard that directory.
    local_path = hf_hub_download(
        repo_id="jameslahm/yoloe",
        filename=filename,
        local_dir=models_dir,
    )
    model = YOLOE(local_path)
    model.eval()
    return model
# ASGI application object; it is what uvicorn serves in the __main__ guard.
app: FastAPI = FastAPI()

# Build the detector once at import time so every request reuses one instance.
model = init_model()
@app.post("/predict")
async def predict(
    image_url: str,
    texts: str = "hat",
    conf: float = 0.25,
    iou: float = 0.7,
):
    """Detect the comma-separated *texts* classes in the image at *image_url*.

    Args:
        image_url: URL of the image to download and run detection on.
        texts: Comma-separated class names used as text prompts (default "hat").
            Blank entries (e.g. from a trailing comma) are ignored.
        conf: Confidence threshold forwarded to ``model.predict``.
        iou: IoU threshold forwarded to ``model.predict``.

    Returns:
        ``{"detections": [...]}``. Each detection is a dict with ``class``
        (name), ``confidence`` (float) and ``bbox`` (``[x1, y1, x2, y2]``).

    Raises:
        HTTPException: 400 if *texts* contains no class names, or if the image
            cannot be fetched or decoded.
    """
    class_list = [name.strip() for name in texts.split(",") if name.strip()]
    if not class_list:
        raise HTTPException(status_code=400, detail="texts must contain at least one class name")

    # NOTE(security): image_url is caller-controlled, so this endpoint fetches
    # arbitrary URLs (SSRF risk). Restrict schemes/hosts before exposing it publicly.
    try:
        # A timeout keeps an unresponsive host from hanging the request forever.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
    except requests.RequestException as err:
        raise HTTPException(status_code=400, detail=f"could not fetch image: {err}") from err

    try:
        # Image.open is lazy. convert("RGB") forces a full decode, so corrupt
        # or non-image payloads are reported here and not inside inference.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(status_code=400, detail=f"could not decode image: {err}") from err

    # NOTE: this handler is async but never awaits, so the event loop runs it
    # to completion before starting another request. That also keeps the
    # shared model's set_classes() and predict() from interleaving.
    text_embeddings = model.get_text_pe(class_list)
    model.set_classes(class_list, text_embeddings)
    results = model.predict(source=image, conf=conf, iou=iou)

    result = results[0]
    boxes = result.boxes
    detections = []
    if boxes is not None:
        # Convert each tensor column to Python lists once, not per box.
        detections = [
            {
                "class": result.names[int(cls_id)],
                "confidence": float(score),
                "bbox": xyxy,  # [x1, y1, x2, y2]
            }
            for cls_id, score, xyxy in zip(
                boxes.cls.tolist(), boxes.conf.tolist(), boxes.xyxy.tolist()
            )
        ]
    print(detections)
    return {"detections": detections}
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly as a script.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)