File size: 4,250 Bytes
d0342ac
 
 
 
bc120ce
 
e9fc911
bc120ce
30b99b4
d0342ac
bc120ce
 
 
 
8eed836
bc120ce
8eed836
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36b5b5a
8eed836
d0342ac
 
 
 
 
 
 
30b99b4
8eed836
bc120ce
36b5b5a
e9fc911
bc120ce
e96a3aa
bc120ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0342ac
bc120ce
 
 
 
 
30b99b4
 
 
d0342ac
 
 
bc120ce
d0342ac
 
 
30b99b4
 
 
 
 
 
 
 
 
d0342ac
 
bc120ce
d0342ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import tempfile
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import clip
import decord
import nncore
import numpy as np
import torch
import pandas as pd
import torchvision.transforms.functional as F
from decord import VideoReader
from nncore.engine import load_checkpoint
from nncore.nn import build_model
from contextlib import asynccontextmanager

# Shared model and config handles; both stay None until the lifespan
# handler assigns them at application startup.
model = None
cfg = None

# Lifespan handler to manage startup and shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    Before ``yield`` (startup) the module-level ``model`` and ``cfg`` are
    populated from ``init_model(CONFIG, WEIGHT)``. After ``yield``
    (shutdown) only a message is printed; nothing is released explicitly.
    """
    global model, cfg

    print("Loading model on startup...")
    loaded_model, loaded_cfg = init_model(CONFIG, WEIGHT)
    model, cfg = loaded_model, loaded_cfg
    print("Model loaded successfully.")

    # The application serves requests while suspended here.
    yield

    print("Shutting down...")

# Create the FastAPI application; the lifespan handler loads the model
# before requests are served.
app = FastAPI(lifespan=lifespan, title="R2-Tuning API")

# CORS: allow the React app served from localhost:3000 to call this API
# with any method and any header, including credentials.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["http://localhost:3000"],
)

# Paths handed to init_model at startup: the model config file and the
# checkpoint holding the trained weights.
CONFIG = "configs/qvhighlights/r2_tuning_qvhighlights.py"
WEIGHT = "r2_tuning_qvhighlights-ed516355.pth"

def convert_time(seconds):
    """Format a duration in seconds as a zero-padded ``MM:SS`` string.

    Negative values are clamped to zero and the result is rounded to the
    nearest whole second before being split into minutes and seconds.
    Minutes are not wrapped at 60, so one hour renders as ``60:00``.
    """
    whole_seconds = round(max(seconds, 0))
    mins, secs = divmod(whole_seconds, 60)
    return f'{mins:02d}:{secs:02d}'

def load_video(video_path, cfg):
    """Decode a video into a flattened, normalised frame tensor.

    Frames are sampled at ``cfg.data.val.fps``, scaled to ``[0, 1]``,
    centre-cropped to a square, resized to 336 or 224 pixels (336 when
    ``'336px'`` appears in ``cfg.model.arch``) and channel-normalised.

    Args:
        video_path: Path passed to ``decord.VideoReader``.
        cfg: Config providing ``data.val.fps`` and ``model.arch``.

    Returns:
        A tensor of shape ``(1, num_sampled_frames, -1)``: each frame is
        flattened into one row and a leading axis of size 1 is added.
    """
    # Make decord hand back torch tensors.
    decord.bridge.set_bridge('torch')
    reader = VideoReader(video_path)
    num_frames = len(reader)

    # Distance, in source frames, between two consecutive sampled frames.
    step = reader.get_avg_fps() / cfg.data.val.fps
    last_index = num_frames - 1
    # Rounding can land one past the end, hence the clamp to last_index.
    frame_ids = [
        min(round(pos), last_index)
        for pos in np.arange(0, num_frames, step).tolist()
    ]

    # assumes decord yields (frames, height, width, channels) -- TODO confirm
    frames = reader.get_batch(frame_ids).permute(0, 3, 1, 2).float() / 255

    if '336px' in cfg.model.arch:
        target = 336
    else:
        target = 224

    # Centre crop to the largest square that fits in the frame.
    height, width = frames.shape[-2], frames.shape[-1]
    side = min(height, width)
    top = round((height - side) / 2)
    left = round((width - side) / 2)
    frames = frames[..., top:top + side, left:left + side]

    frames = F.resize(frames, size=(target, target))
    # NOTE(review): these look like rounded CLIP mean/std values -- confirm.
    frames = F.normalize(frames, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))
    return frames.reshape(frames.size(0), -1).unsqueeze(0)

def init_model(config, checkpoint):
    """Build the model described by a config file and load its weights.

    Args:
        config: Path of the config file read by ``nncore.Config.from_file``.
        checkpoint: Checkpoint passed to ``load_checkpoint``.

    Returns:
        A ``(model, cfg)`` tuple: the model in eval mode with the
        checkpoint loaded, and the parsed config with ``model.init`` set
        to ``True``.
    """
    settings = nncore.Config.from_file(config)
    settings.model.init = True

    network = build_model(settings.model, dist=False)
    network = network.eval()
    network = load_checkpoint(network, checkpoint, warning=False)
    return network, settings

def process_video(video_path: str, query: str, model, cfg) -> dict:
    """Run moment retrieval and highlight detection for one video and query.

    Args:
        video_path: Path of the video file, decoded with ``load_video``.
        query: Text query; tokenized with ``clip.tokenize`` and truncated
            if it is too long.
        model: Model called with a dict holding ``video``, ``query`` and
            ``fps``; inputs are moved to the device of its parameters.
        cfg: Config providing ``data.val.fps`` (and whatever
            ``load_video`` reads).

    Returns:
        A dict with two keys:

        * ``"moment_retrieval"``: up to five ``[start, end, score]``
          entries, with ``start``/``end`` formatted by ``convert_time``
          and ``score`` rounded to two decimals.
        * ``"highlight_detection"``: a list of ``{"x": ..., "y": ...}``
          points, where ``x`` is the clip index times 2 and ``y`` is the
          saliency min-max scaled into ``[0.05, 0.95]``.

    Raises:
        ValueError: If ``query`` is empty or the video cannot be loaded.
    """
    if not query:
        raise ValueError("Text query cannot be empty.")
    try:
        video = load_video(video_path, cfg)
    except Exception as e:
        # Keep the original cause attached for server-side debugging.
        raise ValueError(f"Failed to load video: {str(e)}") from e

    query = clip.tokenize(query, truncate=True)
    device = next(model.parameters()).device
    data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])
    with torch.inference_mode():
        pred = model(data)

    mr = pred['_out']['boundary'][:5].cpu().tolist()
    mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]

    saliency = pred['_out']['saliency'].cpu()
    if saliency.numel() == 0:
        # min()/max() raise on an empty tensor; there is nothing to scale.
        hd = saliency.tolist()
    else:
        low = saliency.min()
        span = saliency.max() - low
        if span.item() > 0:
            hd = ((saliency - low) / span * 0.9 + 0.05).tolist()
        else:
            # Constant saliency (e.g. a single clip) would divide by zero
            # and produce NaN, which cannot be serialised as JSON. Use the
            # midpoint of the [0.05, 0.95] output range instead.
            hd = torch.full(saliency.shape, 0.5).tolist()
    # NOTE(review): the factor 2 presumably is seconds per clip at 0.5 fps
    # -- confirm against cfg.data.val.fps.
    hd = [{"x": i * 2, "y": y} for i, y in enumerate(hd)]
    return {"moment_retrieval": mr, "highlight_detection": hd}

@app.post("/predict")
async def predict(video: UploadFile = File(...), query: str = Form(...)):
    """Handle an uploaded video and text query and return model predictions.

    The upload is written to a temporary ``.mp4`` file, passed to
    ``process_video`` with the module-level ``model`` and ``cfg``, and the
    temporary file is removed afterwards.

    Args:
        video: Uploaded file; its content type must start with ``video/``.
        query: Text query; must contain non-whitespace characters.

    Returns:
        The dict produced by ``process_video``.

    Raises:
        HTTPException: 400 for a non-video upload, a blank query, or a
            ``ValueError`` from processing; 500 for any other failure.
    """
    # Validate outside the try block: raised inside it, these 400 responses
    # would be caught by the generic handler below and re-wrapped as 500.
    # content_type may be None when the client sends no Content-Type.
    if not (video.content_type or "").startswith("video/"):
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload a video.")
    if not query.strip():
        raise HTTPException(status_code=400, detail="Text query cannot be empty.")

    try:
        temp_file_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
                # Record the path first so the file is still removed if
                # reading or writing the upload fails.
                temp_file_path = temp_file.name
                temp_file.write(await video.read())
            # NOTE: process_video is synchronous, so the event loop is
            # blocked for the duration of decoding and inference.
            return process_video(temp_file_path, query, model, cfg)
        finally:
            if temp_file_path is not None:
                os.unlink(temp_file_path)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}") from e

@app.get("/health")
async def health():
    """Health-check endpoint; always returns a fixed healthy status."""
    status_report = {"status": "healthy"}
    return status_report