Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,53 +1,43 @@
|
|
1 |
-
|
2 |
-
import
|
3 |
-
from
|
4 |
-
|
5 |
import clip
|
6 |
import decord
|
7 |
-
import gradio as gr
|
8 |
import nncore
|
9 |
import numpy as np
|
10 |
import torch
|
|
|
11 |
import torchvision.transforms.functional as F
|
12 |
from decord import VideoReader
|
13 |
from nncore.engine import load_checkpoint
|
14 |
from nncore.nn import build_model
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
TITLE = '🌀R2-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding'
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
|
|
23 |
|
24 |
CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
|
25 |
WEIGHT = 'r2_tuning_qvhighlights-ed516355.pth'
|
26 |
|
27 |
-
|
28 |
-
EXAMPLES = [
|
29 |
-
('data/gTAvxnQtjXM_60.0_210.0.mp4', 'A man in a white t shirt wearing a backpack is showing a nearby cathedral.'),
|
30 |
-
('data/pA6Z-qYhSNg_210.0_360.0.mp4', 'Different Facebook posts on transgender bathrooms are shown.'),
|
31 |
-
('data/CkWOpyrAXdw_210.0_360.0.mp4', 'Indian girl cleaning her kitchen before cooking.'),
|
32 |
-
('data/ocLUzCNodj4_360.0_510.0.mp4', 'A woman stands in her bedroom in front of a mirror and talks.'),
|
33 |
-
('data/HkLfNhgP0TM_660.0_810.0.mp4', 'Woman lays down on the couch while talking to the camera.')
|
34 |
-
]
|
35 |
-
# yapf:enable
|
36 |
-
|
37 |
|
38 |
def convert_time(seconds):
|
39 |
minutes, seconds = divmod(round(max(seconds, 0)), 60)
|
40 |
return f'{minutes:02d}:{seconds:02d}'
|
41 |
|
42 |
-
|
43 |
def load_video(video_path, cfg):
|
44 |
decord.bridge.set_bridge('torch')
|
45 |
-
|
46 |
vr = VideoReader(video_path)
|
47 |
stride = vr.get_avg_fps() / cfg.data.val.fps
|
48 |
fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
|
49 |
video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255
|
50 |
-
|
51 |
size = 336 if '336px' in cfg.model.arch else 224
|
52 |
h, w = video.size(-2), video.size(-1)
|
53 |
s = min(h, w)
|
@@ -55,81 +45,60 @@ def load_video(video_path, cfg):
|
|
55 |
video = video[..., x:x + s, y:y + s]
|
56 |
video = F.resize(video, size=(size, size))
|
57 |
video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))
|
58 |
-
|
59 |
-
|
60 |
-
return video
|
61 |
-
|
62 |
|
63 |
def init_model(config, checkpoint):
|
64 |
cfg = nncore.Config.from_file(config)
|
65 |
cfg.model.init = True
|
66 |
-
|
67 |
-
if checkpoint.startswith('http'):
|
68 |
-
checkpoint = nncore.download(checkpoint, out_dir='checkpoints', verbose=False)
|
69 |
-
|
70 |
model = build_model(cfg.model, dist=False).eval()
|
71 |
model = load_checkpoint(model, checkpoint, warning=False)
|
72 |
-
|
73 |
return model, cfg
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
raise gr.Error('Text query can not be empty.')
|
79 |
-
|
80 |
try:
|
81 |
-
video = load_video(
|
82 |
-
except Exception:
|
83 |
-
raise
|
84 |
-
|
85 |
query = clip.tokenize(query, truncate=True)
|
86 |
-
|
87 |
device = next(model.parameters()).device
|
88 |
data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])
|
89 |
-
|
90 |
with torch.inference_mode():
|
91 |
pred = model(data)
|
92 |
-
|
93 |
mr = pred['_out']['boundary'][:5].cpu().tolist()
|
94 |
mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]
|
95 |
-
|
96 |
hd = pred['_out']['saliency'].cpu()
|
97 |
hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).tolist()
|
98 |
-
hd =
|
99 |
-
|
100 |
-
return mr, hd
|
101 |
-
|
102 |
|
103 |
-
|
|
|
|
|
|
|
|
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
x_title='Time (seconds)',
|
129 |
-
y_title='Saliency Score',
|
130 |
-
label='Highlight Detection')
|
131 |
-
|
132 |
-
random_btn.click(lambda: random.sample(EXAMPLES, 1)[0], None, [video, query])
|
133 |
-
submit_btn.click(fn, [video, query], [mr, hd])
|
134 |
-
|
135 |
-
demo.launch()
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
5 |
import clip
|
6 |
import decord
|
|
|
7 |
import nncore
|
8 |
import numpy as np
|
9 |
import torch
|
10 |
+
import pandas as pd
|
11 |
import torchvision.transforms.functional as F
|
12 |
from decord import VideoReader
|
13 |
from nncore.engine import load_checkpoint
|
14 |
from nncore.nn import build_model
|
15 |
|
16 |
+
app = FastAPI(title="R2-Tuning API")
|
|
|
|
|
17 |
|
18 |
+
app.add_middleware(
|
19 |
+
CORSMiddleware,
|
20 |
+
allow_origins=["http://localhost:3000"],
|
21 |
+
allow_credentials=True,
|
22 |
+
allow_methods=["*"],
|
23 |
+
allow_headers=["*"],
|
24 |
+
)
|
25 |
|
26 |
CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
|
27 |
WEIGHT = 'r2_tuning_qvhighlights-ed516355.pth'
|
28 |
|
29 |
+
model, cfg = None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
def convert_time(seconds):
|
32 |
minutes, seconds = divmod(round(max(seconds, 0)), 60)
|
33 |
return f'{minutes:02d}:{seconds:02d}'
|
34 |
|
|
|
35 |
def load_video(video_path, cfg):
|
36 |
decord.bridge.set_bridge('torch')
|
|
|
37 |
vr = VideoReader(video_path)
|
38 |
stride = vr.get_avg_fps() / cfg.data.val.fps
|
39 |
fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
|
40 |
video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255
|
|
|
41 |
size = 336 if '336px' in cfg.model.arch else 224
|
42 |
h, w = video.size(-2), video.size(-1)
|
43 |
s = min(h, w)
|
|
|
45 |
video = video[..., x:x + s, y:y + s]
|
46 |
video = F.resize(video, size=(size, size))
|
47 |
video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))
|
48 |
+
return video.reshape(video.size(0), -1).unsqueeze(0)
|
|
|
|
|
|
|
49 |
|
50 |
def init_model(config, checkpoint):
|
51 |
cfg = nncore.Config.from_file(config)
|
52 |
cfg.model.init = True
|
|
|
|
|
|
|
|
|
53 |
model = build_model(cfg.model, dist=False).eval()
|
54 |
model = load_checkpoint(model, checkpoint, warning=False)
|
|
|
55 |
return model, cfg
|
56 |
|
57 |
+
def process_video(video_path: str, query: str, model, cfg) -> dict:
|
58 |
+
if not query:
|
59 |
+
raise ValueError("Text query cannot be empty.")
|
|
|
|
|
60 |
try:
|
61 |
+
video = load_video(video_path, cfg)
|
62 |
+
except Exception as e:
|
63 |
+
raise ValueError(f"Failed to load video: {str(e)}")
|
|
|
64 |
query = clip.tokenize(query, truncate=True)
|
|
|
65 |
device = next(model.parameters()).device
|
66 |
data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])
|
|
|
67 |
with torch.inference_mode():
|
68 |
pred = model(data)
|
|
|
69 |
mr = pred['_out']['boundary'][:5].cpu().tolist()
|
70 |
mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]
|
|
|
71 |
hd = pred['_out']['saliency'].cpu()
|
72 |
hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).tolist()
|
73 |
+
hd = [{"x": i * 2, "y": y} for i, y in enumerate(hd)]
|
74 |
+
return {"moment_retrieval": mr, "highlight_detection": hd}
|
|
|
|
|
75 |
|
76 |
+
@app.on_event("startup")
|
77 |
+
async def startup_event():
|
78 |
+
global model, cfg
|
79 |
+
model, cfg = init_model(CONFIG, WEIGHT)
|
80 |
+
print("Model loaded successfully.")
|
81 |
|
82 |
+
@app.post("/predict")
|
83 |
+
async def predict(video: UploadFile = File(...), query: str = Form(...)):
|
84 |
+
try:
|
85 |
+
if not video.content_type.startswith("video/"):
|
86 |
+
raise HTTPException(status_code=400, detail="Invalid file type. Please upload a video.")
|
87 |
+
if not query.strip():
|
88 |
+
raise HTTPException(status_code=400, detail="Text query cannot be empty.")
|
89 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
|
90 |
+
temp_file.write(await video.read())
|
91 |
+
temp_file_path = temp_file.name
|
92 |
+
try:
|
93 |
+
result = process_video(temp_file_path, query, model, cfg)
|
94 |
+
return result
|
95 |
+
finally:
|
96 |
+
os.unlink(temp_file_path)
|
97 |
+
except ValueError as e:
|
98 |
+
raise HTTPException(status_code=400, detail=str(e))
|
99 |
+
except Exception as e:
|
100 |
+
raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")
|
101 |
+
|
102 |
+
@app.get("/health")
|
103 |
+
async def health():
|
104 |
+
return {"status": "healthy"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|