Spaces:
Sleeping
Sleeping
File size: 4,250 Bytes
d0342ac bc120ce e9fc911 bc120ce 30b99b4 d0342ac bc120ce 8eed836 bc120ce 8eed836 36b5b5a 8eed836 d0342ac 30b99b4 8eed836 bc120ce 36b5b5a e9fc911 bc120ce e96a3aa bc120ce d0342ac bc120ce 30b99b4 d0342ac bc120ce d0342ac 30b99b4 d0342ac bc120ce d0342ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import os
import tempfile
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import clip
import decord
import nncore
import numpy as np
import torch
import pandas as pd
import torchvision.transforms.functional as F
from decord import VideoReader
from nncore.engine import load_checkpoint
from nncore.nn import build_model
from contextlib import asynccontextmanager
# Module-level handles for the loaded model and its config.
# Both are None at import time and are assigned by the lifespan startup hook.
model = None
cfg = None
# Lifespan handler to manage startup and shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model once at application startup and log on shutdown.

    Everything before ``yield`` runs at startup and everything after it runs
    at shutdown. The loaded model and config are stored in the module-level
    ``model`` and ``cfg`` globals.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    # Startup: build the model from CONFIG and load the WEIGHT checkpoint.
    global model, cfg
    print("Loading model on startup...")
    model, cfg = init_model(CONFIG, WEIGHT)
    print("Model loaded successfully.")

    yield  # Application runs here

    # Shutdown: nothing is released explicitly at the moment.
    print("Shutting down...")
# Configuration: model config file and pretrained checkpoint.
# Both are read by the lifespan hook when the application starts.
CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
WEIGHT = 'r2_tuning_qvhighlights-ed516355.pth'

# Create the FastAPI application and attach the lifespan handler.
app = FastAPI(lifespan=lifespan, title="R2-Tuning API")

# Enable CORS so the React app served from localhost:3000 can call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
def convert_time(seconds: float) -> str:
    """Format a duration in seconds as a zero-padded ``MM:SS`` string.

    Negative inputs are clamped to zero and the value is rounded to a whole
    second with the built-in ``round``. Minutes are not wrapped into hours,
    so 3600 seconds is rendered as ``60:00``.

    Args:
        seconds: Duration in seconds.

    Returns:
        The duration formatted as ``MM:SS``.
    """
    total = round(max(seconds, 0))
    minutes, secs = divmod(total, 60)
    return f'{minutes:02d}:{secs:02d}'
def load_video(video_path, cfg):
    """Decode a video file into a flattened, normalised frame tensor.

    Frames are sampled at roughly ``cfg.data.val.fps`` frames per second,
    centre-cropped to a square, resized to the model input size and
    normalised per channel.

    Args:
        video_path: Path of the video file handed to decord's ``VideoReader``.
        cfg: Config object; ``cfg.data.val.fps`` and ``cfg.model.arch`` are
            read.

    Returns:
        A tensor of shape ``(1, num_frames, features)`` in which every
        sampled frame has been flattened to one vector.
    """
    # Ask decord for torch tensors so the tensor methods below apply.
    decord.bridge.set_bridge('torch')
    vr = VideoReader(video_path)

    # Number of source frames to advance per sampled frame.
    stride = vr.get_avg_fps() / cfg.data.val.fps
    # Clamp to the last valid index because rounding can step past the end.
    fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]

    # assumes decord yields (N, H, W, C) frames with 0-255 values — TODO confirm.
    # Move channels first and scale by 1/255.
    video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255

    # Input resolution depends on whether the arch name mentions '336px'.
    size = 336 if '336px' in cfg.model.arch else 224

    # Centre-crop to a square whose side is the shorter frame dimension.
    h, w = video.size(-2), video.size(-1)
    s = min(h, w)
    x, y = round((h - s) / 2), round((w - s) / 2)
    video = video[..., x:x + s, y:y + s]

    video = F.resize(video, size=(size, size))
    # NOTE(review): these look like CLIP's per-channel mean/std — confirm.
    video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))

    # Flatten each frame and add a leading batch dimension of 1.
    return video.reshape(video.size(0), -1).unsqueeze(0)
def init_model(config, checkpoint):
    """Build the model described by a config file and load its weights.

    Args:
        config: Path of the config file read with ``nncore.Config.from_file``.
        checkpoint: Checkpoint passed to ``load_checkpoint``.

    Returns:
        A ``(model, cfg)`` tuple: the model in eval mode with the checkpoint
        loaded, and the parsed config.
    """
    cfg = nncore.Config.from_file(config)
    # NOTE(review): the meaning of the `init` flag is defined by the model
    # class, which is not visible here — confirm before changing.
    cfg.model.init = True
    model = build_model(cfg.model, dist=False).eval()
    model = load_checkpoint(model, checkpoint, warning=False)
    return model, cfg
def process_video(video_path: str, query: str, model, cfg) -> dict:
    """Run the model on one video/query pair and format the predictions.

    Args:
        video_path: Path of the video file to analyse.
        query: Natural-language query; must be non-empty.
        model: Loaded model; called with a dict holding ``video``, ``query``
            and ``fps`` entries.
        cfg: Config object; ``cfg.data.val.fps`` is read here and the object
            is forwarded to ``load_video``.

    Returns:
        A dict with two keys:
        ``"moment_retrieval"`` is a list of ``[start, end, score]`` entries
        built from the first five rows of the predicted boundaries, with
        start/end formatted by ``convert_time`` and the score rounded to two
        decimals. ``"highlight_detection"`` is a list of ``{"x", "y"}``
        points, one per saliency value, with ``y`` rescaled into
        ``[0.05, 0.95]``.

    Raises:
        ValueError: If ``query`` is empty or the video cannot be loaded.
    """
    if not query:
        raise ValueError("Text query cannot be empty.")

    try:
        video = load_video(video_path, cfg)
    except Exception as e:
        # Chain the original error so the decoder traceback is not lost.
        raise ValueError(f"Failed to load video: {str(e)}") from e

    tokens = clip.tokenize(query, truncate=True)
    # Run on whatever device the model's parameters live on.
    device = next(model.parameters()).device
    data = dict(video=video.to(device), query=tokens.to(device), fps=[cfg.data.val.fps])

    with torch.inference_mode():
        pred = model(data)

    # presumably each boundary row is (start_seconds, end_seconds, score) —
    # verify against the model's output format.
    mr = pred['_out']['boundary'][:5].cpu().tolist()
    mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]

    hd = _rescale_saliency(pred['_out']['saliency'].cpu()).tolist()
    # NOTE(review): the factor 2 looks like seconds per clip (1 / fps) for
    # this config — confirm before changing cfg.data.val.fps.
    hd = [{"x": i * 2, "y": y} for i, y in enumerate(hd)]

    return {"moment_retrieval": mr, "highlight_detection": hd}


def _rescale_saliency(saliency):
    """Min-max rescale saliency scores into the range [0.05, 0.95].

    An empty tensor is returned unchanged. When the scores have no usable
    spread (all equal, or the range is NaN/infinite) every value is set to
    the midpoint 0.5, which avoids a division by zero and the NaN values it
    would produce.
    """
    if saliency.numel() == 0:
        return saliency
    low, high = saliency.min(), saliency.max()
    span = high - low
    if not torch.isfinite(span) or span <= 0:
        return torch.full_like(saliency, 0.5)
    return (saliency - low) / span * 0.9 + 0.05
@app.post("/predict")
async def predict(video: UploadFile = File(...), query: str = Form(...)):
    """Handle an uploaded video and text query and return model predictions.

    The upload is written to a temporary ``.mp4`` file, passed to
    ``process_video`` together with the module-level ``model`` and ``cfg``,
    and the temporary file is removed afterwards.

    Args:
        video: Uploaded file; its content type must start with ``video/``.
        query: Text query; must contain non-whitespace characters.

    Returns:
        The dict produced by ``process_video``.

    Raises:
        HTTPException: 400 for an invalid upload, an empty query or a
            ``ValueError`` from processing; 500 for any other failure.
    """
    try:
        # content_type can be None when the client sends no Content-Type
        # for the part; treat that as an invalid upload instead of crashing.
        if not (video.content_type or "").startswith("video/"):
            raise HTTPException(status_code=400, detail="Invalid file type. Please upload a video.")
        if not query.strip():
            raise HTTPException(status_code=400, detail="Text query cannot be empty.")

        temp_file_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
                # Record the path before writing so that a failed write
                # still gets cleaned up by the finally clause.
                temp_file_path = temp_file.name
                temp_file.write(await video.read())
            # NOTE(review): process_video is a synchronous call inside an
            # async handler, so it runs on the event loop thread.
            return process_video(temp_file_path, query, model, cfg)
        finally:
            if temp_file_path is not None:
                os.unlink(temp_file_path)
    except HTTPException:
        # Let the deliberate 400 responses through; without this clause the
        # generic handler below would re-wrap them as 500 errors.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}") from e
@app.get("/health")
async def health():
    """Report that the service process is up.

    Always returns the same payload; it does not check whether the model
    has been loaded.
    """
    return {"status": "healthy"}