"""FastAPI service exposing R2-Tuning moment retrieval and highlight detection."""
import os
import tempfile
from contextlib import asynccontextmanager

import clip
import decord
import nncore
import numpy as np
import pandas as pd  # NOTE(review): appears unused in this module
import torch
import torchvision.transforms.functional as F
from decord import VideoReader
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from nncore.engine import load_checkpoint
from nncore.nn import build_model

# Model config file and checkpoint, loaded once by ``lifespan`` at startup.
CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
WEIGHT = 'r2_tuning_qvhighlights-ed516355.pth'

# Populated by ``lifespan`` at startup and read by the ``/predict`` handler.
model, cfg = None, None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model and config on startup and keep them for the app's lifetime."""
    global model, cfg
    print("Loading model on startup...")
    model, cfg = init_model(CONFIG, WEIGHT)
    print("Model loaded successfully.")
    yield  # The application serves requests while suspended here.
    print("Shutting down...")


app = FastAPI(title="R2-Tuning API", lifespan=lifespan)

# Allow the React front end served from localhost:3000 to call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def convert_time(seconds):
    """Format a number of seconds as a ``MM:SS`` string.

    Negative values are clamped to 0 and the value is rounded to whole seconds
    with Python's ``round``. Minutes are not wrapped into hours, so 3600 s
    becomes ``"60:00"``.
    """
    minutes, seconds = divmod(round(max(seconds, 0)), 60)
    return f'{minutes:02d}:{seconds:02d}'


def load_video(video_path, cfg):
    """Decode, subsample and preprocess a video file for the model.

    Frames are sampled at ``cfg.data.val.fps`` frames per second. Each frame is
    center-cropped to a square, resized to 336 px (when ``cfg.model.arch``
    contains ``'336px'``) or 224 px, and normalized.

    Args:
        video_path: Path to a video file readable by decord.
        cfg: Loaded nncore config; reads ``data.val.fps`` and ``model.arch``.

    Returns:
        A tensor of shape ``(1, num_frames, C * size * size)``.
    """
    decord.bridge.set_bridge('torch')  # make get_batch return torch tensors
    vr = VideoReader(video_path)

    # Step through the source frames so roughly cfg.data.val.fps frames are kept per second.
    stride = vr.get_avg_fps() / cfg.data.val.fps
    fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]

    # decord yields (N, H, W, C) frames; convert to (N, C, H, W) floats scaled by 1/255.
    video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255

    size = 336 if '336px' in cfg.model.arch else 224
    h, w = video.size(-2), video.size(-1)
    s = min(h, w)
    x, y = round((h - s) / 2), round((w - s) / 2)
    video = video[..., x:x + s, y:y + s]  # center square crop
    video = F.resize(video, size=(size, size))
    # These constants appear to be CLIP's per-channel mean/std (rounded).
    video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))
    return video.reshape(video.size(0), -1).unsqueeze(0)


def init_model(config, checkpoint):
    """Build the model described by ``config`` and load weights from ``checkpoint``.

    Args:
        config: Path to an nncore config file.
        checkpoint: Path to the weight file passed to ``load_checkpoint``.

    Returns:
        ``(model, cfg)``: the model (``.eval()`` applied) and the parsed config.
    """
    cfg = nncore.Config.from_file(config)
    # NOTE(review): flag consumed by build_model; its exact meaning is not visible here.
    cfg.model.init = True
    model = build_model(cfg.model, dist=False).eval()
    model = load_checkpoint(model, checkpoint, warning=False)
    return model, cfg


def process_video(video_path: str, query: str, model, cfg) -> dict:
    """Run moment retrieval and highlight detection for one video and text query.

    Args:
        video_path: Path to the video file on disk.
        query: Natural-language query; must be non-empty.
        model: Model returned by ``init_model``.
        cfg: Config returned by ``init_model``.

    Returns:
        A dict with two keys:

        - ``"moment_retrieval"``: up to 5 ``[start, end, score]`` entries, with
          start/end formatted as ``MM:SS`` and score rounded to 2 decimals.
        - ``"highlight_detection"``: one ``{"x": ..., "y": ...}`` point per
          saliency value, with ``y`` rescaled into ``[0.05, 0.95]``.

    Raises:
        ValueError: If ``query`` is empty or the video cannot be loaded.
    """
    if not query:
        raise ValueError("Text query cannot be empty.")
    try:
        video = load_video(video_path, cfg)
    except Exception as e:
        raise ValueError(f"Failed to load video: {str(e)}") from e

    query = clip.tokenize(query, truncate=True)
    device = next(model.parameters()).device
    data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])

    with torch.inference_mode():
        pred = model(data)

    # Keep the first 5 boundary rows; each row is read as [start_sec, end_sec, score].
    mr = pred['_out']['boundary'][:5].cpu().tolist()
    mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]

    hd = pred['_out']['saliency'].cpu()
    lo, hi = hd.min(), hd.max()
    if hi > lo:
        # Min-max scale into [0.05, 0.95].
        hd = (hd - lo) / (hi - lo) * 0.9 + 0.05
    else:
        # Constant saliency (e.g. a single clip): min-max scaling would divide by
        # zero and yield NaN, which cannot be serialized as standard JSON.
        hd = torch.full_like(hd, 0.5)
    hd = hd.tolist()
    # NOTE(review): x assumes each saliency step spans 2 s; confirm against cfg.data.val.fps.
    hd = [{"x": i * 2, "y": y} for i, y in enumerate(hd)]

    return {"moment_retrieval": mr, "highlight_detection": hd}


@app.post("/predict")
async def predict(video: UploadFile = File(...), query: str = Form(...)):
    """Accept an uploaded video plus a text query and return model predictions.

    Responds with 400 for a non-video upload, an empty query, or a video that
    cannot be loaded. Responds with 500 for any other failure.
    """
    # Validate before entering the try below so these 400s are not re-wrapped as 500s.
    if not (video.content_type or "").startswith("video/"):
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload a video.")
    if not query.strip():
        raise HTTPException(status_code=400, detail="Text query cannot be empty.")

    temp_file_path = None
    try:
        # delete=False so the decoder can reopen the file by name after it is closed.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
            # Record the path first so cleanup also runs if the write fails.
            temp_file_path = temp_file.name
            temp_file.write(await video.read())
        # NOTE(review): inference runs synchronously inside this async handler and
        # blocks the event loop until it finishes.
        return process_video(temp_file_path, query, model, cfg)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}") from e
    finally:
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError as e:
                # Best-effort cleanup: a leftover temp file should not fail the request.
                print(f"Warning: could not remove temp file {temp_file_path}: {e}")


@app.get("/health")
async def health():
    """Liveness probe; always reports healthy."""
    return {"status": "healthy"}