Spaces:
Sleeping
Sleeping
import os
import time
from functools import lru_cache

from fastapi import APIRouter, HTTPException

from scripts.data_model import ImageInput, ImageOutput
from utils.pipeline import load_model
# Router object for this module's endpoints.
router = APIRouter()

# Base directory: two levels above this file's absolute location.
BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)

# Location of the pose-classification model under <BASE_DIR>/ml-models.
# The trailing "/" on the last component is intentional and preserved.
MODEL_PATH = os.path.join(
    BASE_DIR,
    "ml-models",
    "vit-human-pose-classification/",
)
@lru_cache(maxsize=1)
def _get_pipeline():
    """Load the pose-classification pipeline once and reuse it across calls.

    Previously ``load_model`` ran on every request. It is not visible here
    whether ``load_model`` caches internally; caching at this layer is harmless
    either way.

    ``lru_cache`` does not cache exceptions, so a failed load is retried on the
    next call. Two concurrent first calls may both load the model; only one
    result is kept.

    NOTE(review): the returned pipeline is now shared across requests, which
    FastAPI may serve from a thread pool. Confirm that concurrent inference on
    one pipeline instance is acceptable. Changes to the model files on disk
    now require a process restart.
    """
    return load_model(MODEL_PATH, is_image_model=True)


def image_classification(input: ImageInput) -> ImageOutput:
    """
    Classify the image(s) at the given URL(s) using a pre-trained model.

    Args:
        input (ImageInput): The input data containing the image URL(s) and user ID.
            ``input.url`` is iterated, and each element is converted with ``str()``.

    Returns:
        ImageOutput: The user ID and URLs echoed back, the model name, the top
        label and score for each URL, and the inference time in milliseconds.

    Raises:
        HTTPException: Status 500 if loading the model, running inference, or
        building the response fails. The original error is chained as the cause.
    """
    try:
        pipe = _get_pipeline()
        urls = [str(x) for x in input.url]
        # perf_counter is monotonic and high-resolution. time.time() can jump
        # (e.g. NTP adjustments), so it is unsuitable for measuring durations.
        start = time.perf_counter()
        output = pipe(urls)
        prediction_time = int((time.perf_counter() - start) * 1000)
        # Each entry of `output` is a list of {'label', 'score'} dicts for one URL.
        # The first entry is taken as the prediction; presumably it is the
        # highest-scoring one (usual pipeline ordering) — confirm.
        top = [preds[0] for preds in output]
        return ImageOutput(
            user_id=input.user_id,
            url=input.url,
            model_name="vit-human-pose-classification",
            label=[p['label'] for p in top],
            score=[p['score'] for p in top],
            prediction_time=prediction_time,
        )
    except Exception as e:
        # Request boundary: convert any failure into a 500. `from e` keeps the
        # original traceback in logs.
        raise HTTPException(status_code=500, detail=f"Failed to process image classification: {e}") from e