# fastapi-yoloe / app.py
# Hosting-page header captured together with the file (not program code):
# wjm55 · init · 3b05c4f · raw · history blame · 2.28 kB
from fastapi import FastAPI, UploadFile
from ultralytics import YOLOE
import io
from PIL import Image
import numpy as np
import os
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
import requests
# Setup commands (YOLOE dependencies and MobileCLIP weights), kept for reference:
#pip install -q "git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/CLIP"
#pip install -q "git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/ml-mobileclip"
#pip install -q "git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/lvis-api"
#pip install -q "git+https://github.com/THU-MIG/yoloe.git"
#wget -q https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt
def init_model(model_id: str = "yoloe-11s", is_pf: bool = True):
    """Download a YOLOE segmentation checkpoint and load it for inference.

    Args:
        model_id: Checkpoint family name used to build the file name.
        is_pf: When true, fetch the ``-seg-pf.pt`` variant of the checkpoint;
            otherwise fetch the ``-seg.pt`` variant.

    Returns:
        The loaded ``YOLOE`` model, after ``eval()`` has been called on it.
    """
    suffix = "-seg-pf.pt" if is_pf else "-seg.pt"
    filename = f"{model_id}{suffix}"

    # Create a models directory if it doesn't exist
    os.makedirs("models", exist_ok=True)

    # Download straight into models/ and use the returned path as-is.
    # Prefixing the returned path with "models" is a no-op when that path is
    # absolute (os.path.join drops the earlier component) and points at a
    # non-existent file when it is relative.
    local_path = hf_hub_download(
        repo_id="jameslahm/yoloe",
        filename=filename,
        local_dir="models",
    )

    model = YOLOE(local_path)
    model.eval()
    return model
# ASGI application object; the /predict route and uvicorn.run both use it.
app = FastAPI()
# Initialize model at startup
# init_model() runs at import time, so the module-level `model` exists before
# any request handler is called.
model = init_model()
@app.post("/predict")
async def predict(image_url: str, texts: str = "hat"):
    """Detect the requested classes in an image fetched from a URL.

    Args:
        image_url: Address of the image to download and analyse.
        texts: Comma-separated class names to look for.

    Returns:
        ``{"detections": [...]}`` where every entry has the keys ``class``,
        ``confidence`` and ``bbox`` (the box's ``xyxy`` values as a list).

    Raises:
        HTTPException: Status 400 when no class name is supplied, the image
            cannot be downloaded, or the payload is not a readable image.
    """
    from fastapi import HTTPException

    # Drop blank entries so "hat, " or "" never hands an empty name to the model.
    class_list = [name for name in (part.strip() for part in texts.split(",")) if name]
    if not class_list:
        raise HTTPException(
            status_code=400,
            detail="texts must contain at least one class name",
        )

    # NOTE(security): image_url is caller-controlled, so this is a server-side
    # fetch of an arbitrary address (SSRF risk). Restrict allowed hosts before
    # exposing this endpoint publicly.
    # NOTE(review): requests.get and model.predict are blocking calls inside an
    # async handler. With no await in the body, calls should not interleave
    # while set_classes mutates the shared model; confirm before making this
    # a sync or threaded handler.
    try:
        # A timeout keeps a slow or unresponsive host from hanging the handler.
        response = requests.get(image_url, timeout=10)
        # Reject error pages instead of trying to decode them as an image.
        response.raise_for_status()
    except requests.RequestException as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Could not download image: {exc}",
        ) from exc

    try:
        image = Image.open(io.BytesIO(response.content))
    except OSError as exc:  # includes PIL's UnidentifiedImageError
        raise HTTPException(
            status_code=400,
            detail="URL did not return a readable image",
        ) from exc

    # Get text embeddings and set classes properly
    # NOTE(review): the model loaded at startup is the "-seg-pf.pt" checkpoint;
    # confirm that it supports get_text_pe/set_classes.
    text_embeddings = model.get_text_pe(class_list)
    model.set_classes(class_list, text_embeddings)

    # Run inference with the PIL Image
    results = model.predict(source=image, conf=0.25, iou=0.7)

    # Only one image is submitted, so only the first result is used.
    result = results[0]
    detections = [
        {
            "class": result.names[int(box.cls[0])],
            "confidence": float(box.conf[0]),
            "bbox": box.xyxy[0].tolist(),  # Convert bbox tensor to list
        }
        for box in result.boxes
    ]
    print(detections)
    return {"detections": detections}
if __name__ == "__main__":
    # Script entry point: uvicorn is imported only when the file is run
    # directly, then serves `app` on all interfaces at port 7860.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")