# Header from the Hugging Face file page: app.py by Aumkeshchy2003,
# commit a29d5e2 ("Update app.py", verified), 1.3 kB.
from fastapi import FastAPI, File, UploadFile
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
import torch
from PIL import Image
from torchvision.transforms import functional as F
from yolov5.models.yolo import Model
from yolov5.utils.general import non_max_suppression
# HTTP application object; routes are registered on it below.
app = FastAPI()

# Run on the current CUDA device when torch can see one, otherwise on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained YOLOv5s detector obtained through torch.hub. nn.Module.to()
# moves the parameters in place, and eval() switches the module to
# evaluation mode before it serves any request.
model = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True)
model.to(device)
model.eval()
def preprocess_image(image, stride=32):
    """Convert a PIL image into a batched tensor the detector can consume.

    The image is first forced to 3-channel RGB. Without this, RGBA,
    grayscale or palette inputs yield a 4- or 1-channel tensor that the
    3-channel network rejects. The tensor is then padded on the right and
    bottom edges only, so that height and width are multiples of ``stride``.
    Because the top-left origin never moves, boxes predicted on the padded
    tensor are already in the original image's pixel coordinates.

    Args:
        image: PIL image in any mode.
        stride: Spatial multiple the network needs for its feature maps to
            line up. 32 matches yolov5s's largest stride; confirm against
            ``model.stride`` if other weights are loaded.

    Returns:
        Float tensor of shape (1, 3, H_pad, W_pad) on ``device``, where
        H_pad and W_pad are the image height/width rounded up to ``stride``.
    """
    tensor = F.to_tensor(image.convert("RGB"))
    _, height, width = tensor.shape
    # Smallest non-negative amount that brings each side up to a multiple
    # of ``stride`` (0 when already aligned).
    pad_bottom = -height % stride
    pad_right = -width % stride
    if pad_bottom or pad_right:
        # Pad tuple is (left, right, top, bottom) for the last two dims.
        # Fill with neutral gray rather than black.
        tensor = torch.nn.functional.pad(
            tensor, (0, pad_right, 0, pad_bottom), value=114 / 255
        )
    return tensor.unsqueeze(0).to(device)
def draw_boxes(outputs, threshold=0.3):
    """Turn detection rows into JSON-serialisable dicts.

    Nothing is drawn despite the name. Each row of ``outputs`` is read as
    ``[x1, y1, x2, y2, score, class_id, ...]``. A row is kept only when
    ``score > threshold``.

    Args:
        outputs: Iterable of 1-D detection rows (e.g. the rows of a 2-D
            tensor); every row must have at least six entries.
        threshold: Exclusive lower bound a detection's score must exceed.

    Returns:
        List of ``{"label", "score", "box"}`` dicts, in input-row order.
        ``label`` is looked up in ``model.names`` by the integer class id.
        ``box`` is ``[x1, y1, x2, y2]``.
    """
    detections = []
    for row in outputs:
        # One host transfer per row instead of one per element.
        values = row.tolist()
        score = values[4]
        class_id = int(values[5])
        x1, y1, x2, y2 = values[0], values[1], values[2], values[3]
        # Written as "not >" (not "<=") so NaN scores are skipped exactly as before.
        if not score > threshold:
            continue
        detections.append(
            {
                "label": model.names[class_id],
                "score": score,
                "box": [x1, y1, x2, y2],
            }
        )
    return detections
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Detect objects in an uploaded image and return them as JSON.

    Args:
        file: Multipart upload that PIL can decode as an image.

    Returns:
        ``JSONResponse`` with body ``{"boxes": [...]}``. Each entry comes from
        ``draw_boxes``: label, score and ``[x1, y1, x2, y2]`` in the uploaded
        image's pixel coordinates.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
            Previously this surfaced as an unhandled 500.
    """
    try:
        # Image.open is lazy. convert() forces the decode here, so a corrupt
        # upload fails inside this try. It also normalises RGBA, grayscale and
        # palette images to 3-channel RGB.
        with Image.open(file.file) as uploaded:
            image = uploaded.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    # The forward pass is compute-bound. Running it directly in this
    # coroutine would stall the event loop, and every other request with it.
    detections = await run_in_threadpool(_run_detector, image)
    return JSONResponse(content={"boxes": draw_boxes(detections)})


def _run_detector(image):
    """Run the hub model on an RGB PIL image and return its detections.

    The PIL image is passed as-is rather than as a raw tensor. That routes it
    through the AutoShape wrapper returned by ``torch.hub.load``, which does
    the following:

    * resizes/letterboxes the image to a stride-compatible size;
    * applies NMS;
    * maps the boxes back to the original image's coordinates;
    * runs without autograd.

    A raw full-resolution tensor, as used before, breaks for image sizes
    that are not multiples of the network stride.

    Returns:
        Tensor for the single input image with one row per detection,
        ``[x1, y1, x2, y2, confidence, class_id]``. This is the row layout
        ``draw_boxes`` reads.
    """
    results = model(image)
    return results.xyxy[0]