Spaces:
Sleeping
Sleeping
import os | |
import tempfile | |
from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
import clip | |
import decord | |
import nncore | |
import numpy as np | |
import torch | |
import pandas as pd | |
import torchvision.transforms.functional as F | |
from decord import VideoReader | |
from nncore.engine import load_checkpoint | |
from nncore.nn import build_model | |
from contextlib import asynccontextmanager | |
# Module-level handles to the loaded model and its config.
# Both stay None until the lifespan startup code assigns them.
model = None
cfg = None
# Lifespan handler to manage startup and shutdown | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model once before the app starts serving; log on shutdown.

    The function is wrapped with ``asynccontextmanager`` so that it can be
    passed as the ``lifespan=`` argument of ``FastAPI`` as a real async
    context manager rather than as a bare async generator.

    Args:
        app: The application instance this lifespan is attached to (unused).

    Yields:
        None. Control returns to the framework while the application runs.
    """
    # Startup: populate the module-level model/config used by the handlers.
    global model, cfg
    print("Loading model on startup...")
    model, cfg = init_model(CONFIG, WEIGHT)
    print("Model loaded successfully.")
    yield  # Application runs here
    # Shutdown: nothing to release explicitly, just log.
    print("Shutting down...")
# Create the application; `lifespan` runs the startup/shutdown logic.
app = FastAPI(lifespan=lifespan, title="R2-Tuning API")

# CORS for the React app: only http://localhost:3000 is an allowed origin,
# with credentials, every HTTP method and every request header permitted.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
    allow_origins=["http://localhost:3000"],
)
# Paths passed to init_model() at startup: the model config file and the
# checkpoint holding the weights.
CONFIG = "configs/qvhighlights/r2_tuning_qvhighlights.py"
WEIGHT = "r2_tuning_qvhighlights-ed516355.pth"
def convert_time(seconds):
    """Format a duration in seconds as a zero-padded ``MM:SS`` string.

    Negative inputs are clamped to zero and fractional inputs are rounded to
    the nearest whole second with the built-in ``round``. Minutes are not
    wrapped into hours, so 3600 seconds is rendered as ``60:00``.
    """
    whole_seconds = round(max(seconds, 0))
    mins = whole_seconds // 60
    secs = whole_seconds % 60
    return '{:02d}:{:02d}'.format(mins, secs)
def load_video(video_path, cfg):
    """Read a video and turn it into the flattened frame tensor for the model.

    Frames are sampled with a stride of ``avg_fps / cfg.data.val.fps``,
    center-cropped to a square, resized to 336 or 224 pixels per side
    (336 when ``'336px'`` appears in ``cfg.model.arch``), normalized per
    channel and flattened.

    Args:
        video_path: Path of the video file handed to decord's ``VideoReader``.
        cfg: Config object; ``cfg.data.val.fps`` and ``cfg.model.arch`` are read.

    Returns:
        A tensor of shape ``(1, num_sampled_frames, features_per_frame)``.
    """
    # Make decord hand back torch tensors.
    decord.bridge.set_bridge('torch')
    reader = VideoReader(video_path)
    num_frames = len(reader)

    # Pick frame indices at the target rate, clamped to the last valid frame.
    step = reader.get_avg_fps() / cfg.data.val.fps
    positions = np.arange(0, num_frames, step).tolist()
    frame_ids = [min(round(pos), num_frames - 1) for pos in positions]

    # Presumably frames arrive as (N, H, W, C) with 0-255 values; move the
    # channel axis to position 1 and rescale by 1/255 -- TODO confirm layout.
    frames = reader.get_batch(frame_ids)
    frames = frames.permute(0, 3, 1, 2).float() / 255

    side = 336 if '336px' in cfg.model.arch else 224

    # Center crop to a square whose side is the shorter spatial dimension.
    height, width = frames.size(-2), frames.size(-1)
    short = min(height, width)
    top = round((height - short) / 2)
    left = round((width - short) / 2)
    frames = frames[..., top:top + short, left:left + short]

    frames = F.resize(frames, size=(side, side))
    # Per-channel mean / std normalization.
    frames = F.normalize(frames, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))

    # Flatten each frame to one vector and add a leading batch dimension.
    return frames.reshape(frames.size(0), -1).unsqueeze(0)
def init_model(config, checkpoint):
    """Build the model described by a config file and load its weights.

    Args:
        config: Path of the config file read with ``nncore.Config.from_file``.
        checkpoint: Checkpoint passed to ``load_checkpoint``.

    Returns:
        A ``(model, cfg)`` tuple: the model in eval mode with the checkpoint
        loaded, and the parsed config.
    """
    settings = nncore.Config.from_file(config)
    # NOTE(review): presumably asks the model to initialize its own
    # sub-modules on construction -- confirm against the model definition.
    settings.model.init = True

    network = build_model(settings.model, dist=False)
    network = network.eval()
    network = load_checkpoint(network, checkpoint, warning=False)
    return network, settings
def process_video(video_path: str, query: str, model, cfg) -> dict:
    """Run the model on one video/query pair and format its predictions.

    Args:
        video_path: Path of the video file to analyse.
        query: Text query; must contain at least one non-whitespace character.
        model: Loaded model; called with a dict holding ``video``, ``query``
            and ``fps`` and expected to return ``pred['_out']`` with
            ``boundary`` and ``saliency`` entries.
        cfg: Config object; ``cfg.data.val.fps`` is read here.

    Returns:
        A dict with two keys:
        ``"moment_retrieval"``: up to five ``[start, end, score]`` lists,
        where the first two boundary columns are formatted by
        ``convert_time`` and the third is rounded to two decimals
        (presumably start time, end time and confidence -- TODO confirm).
        ``"highlight_detection"``: a list of ``{"x": ..., "y": ...}`` points,
        with ``y`` being the saliency rescaled into ``[0.05, 0.95]``.

    Raises:
        ValueError: If the query is empty/blank or the video cannot be loaded.
    """
    # Reject blank queries as well as empty ones, matching the API handler.
    if not query or (isinstance(query, str) and not query.strip()):
        raise ValueError("Text query cannot be empty.")

    try:
        video = load_video(video_path, cfg)
    except Exception as e:
        # Keep the original exception as the cause for easier debugging.
        raise ValueError(f"Failed to load video: {str(e)}") from e

    query = clip.tokenize(query, truncate=True)
    # Run on whatever device the model's parameters live on.
    device = next(model.parameters()).device
    data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])

    with torch.inference_mode():
        pred = model(data)

    # Keep only the first five boundary rows.
    mr = pred['_out']['boundary'][:5].cpu().tolist()
    mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]

    hd = pred['_out']['saliency'].cpu()
    span = hd.max() - hd.min()
    if span.item() > 0:
        # Min-max scale into [0.05, 0.95].
        hd = (hd - hd.min()) / span * 0.9 + 0.05
    else:
        # Flat saliency: min-max scaling would divide by zero and produce
        # NaN, which is not valid JSON. Use the middle of the range instead.
        hd = torch.full_like(hd, 0.5)

    # NOTE(review): x advances by 2 per saliency value, presumably seconds
    # at one value every 2 s -- confirm against cfg.data.val.fps.
    hd = [{"x": i * 2, "y": y} for i, y in enumerate(hd.tolist())]
    return {"moment_retrieval": mr, "highlight_detection": hd}
async def predict(video: UploadFile = File(...), query: str = Form(...)):
    """Handle an upload: validate it, run the model, return the predictions.

    The upload is written to a temporary ``.mp4`` file, processed with
    ``process_video`` using the module-level ``model`` and ``cfg``, and the
    temporary file is removed afterwards whether or not processing succeeded.

    Args:
        video: Uploaded file; its content type must start with ``video/``.
        query: Text query form field; must not be blank.

    Returns:
        The dict produced by ``process_video``.

    Raises:
        HTTPException: 400 for a non-video upload, a blank query, or a
            ``ValueError`` raised while processing; 500 for any other error.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view of the file -- confirm it is registered on `app`.

    # Validate before entering the try block: an HTTPException raised inside
    # it would be caught by the generic handler and turned into a 500.
    # content_type may be missing, so treat None as an empty string.
    if not (video.content_type or "").startswith("video/"):
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload a video.")
    if not query.strip():
        raise HTTPException(status_code=400, detail="Text query cannot be empty.")

    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
            # Record the path first so the file is cleaned up even if
            # reading or writing the upload fails.
            temp_file_path = temp_file.name
            temp_file.write(await video.read())
        # The file is closed here, so it can be reopened by the video reader.
        return process_video(temp_file_path, query, model, cfg)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}") from e
    finally:
        if temp_file_path is not None:
            os.unlink(temp_file_path)
async def health():
    """Report that the service is up; always returns a fixed payload."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view of the file -- confirm it is registered on `app`.
    payload = {"status": "healthy"}
    return payload