MisbahKhan committed on
Commit d0342ac · verified · 1 Parent(s): 63f5061

Update app.py

Files changed (1)
  1. app.py +51 -82
app.py CHANGED
@@ -1,53 +1,43 @@
-
-import random
-from functools import partial
-
+import os
+import tempfile
+from fastapi import FastAPI, UploadFile, File, Form, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
 import clip
 import decord
-import gradio as gr
 import nncore
 import numpy as np
 import torch
+import pandas as pd
 import torchvision.transforms.functional as F
 from decord import VideoReader
 from nncore.engine import load_checkpoint
 from nncore.nn import build_model
 
-import pandas as pd
-
-TITLE = '🌀R2-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding'
+app = FastAPI(title="R2-Tuning API")
 
-TITLE_MD = '<h1 align="center">🌀R<sup>2</sup>-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding</h1>'
-DESCRIPTION_MD = 'R<sup>2</sup>-Tuning is a parameter- and memory-efficient transfer learning method for video temporal grounding. Please find more details in our <a href="https://arxiv.org/abs/2404.00801" target="_blank">Tech Report</a> and <a href="https://github.com/yeliudev/R2-Tuning" target="_blank">GitHub Repo</a>.'
-GUIDE_MD = '### User Guide:\n1. Upload a video or click "random" to sample one.\n2. Input a text query. A good practice is to write a sentence with 5~15 words.\n3. Click "submit" and you\'ll see the moment retrieval and highlight detection results on the right.'
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["http://localhost:3000"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
 
 CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
 WEIGHT = 'r2_tuning_qvhighlights-ed516355.pth'
 
-# yapf:disable
-EXAMPLES = [
-    ('data/gTAvxnQtjXM_60.0_210.0.mp4', 'A man in a white t shirt wearing a backpack is showing a nearby cathedral.'),
-    ('data/pA6Z-qYhSNg_210.0_360.0.mp4', 'Different Facebook posts on transgender bathrooms are shown.'),
-    ('data/CkWOpyrAXdw_210.0_360.0.mp4', 'Indian girl cleaning her kitchen before cooking.'),
-    ('data/ocLUzCNodj4_360.0_510.0.mp4', 'A woman stands in her bedroom in front of a mirror and talks.'),
-    ('data/HkLfNhgP0TM_660.0_810.0.mp4', 'Woman lays down on the couch while talking to the camera.')
-]
-# yapf:enable
-
+model, cfg = None, None
 
 def convert_time(seconds):
     minutes, seconds = divmod(round(max(seconds, 0)), 60)
     return f'{minutes:02d}:{seconds:02d}'
 
-
 def load_video(video_path, cfg):
     decord.bridge.set_bridge('torch')
-
     vr = VideoReader(video_path)
     stride = vr.get_avg_fps() / cfg.data.val.fps
     fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
     video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255
-
     size = 336 if '336px' in cfg.model.arch else 224
     h, w = video.size(-2), video.size(-1)
     s = min(h, w)
@@ -55,81 +45,60 @@ def load_video(video_path, cfg):
     video = video[..., x:x + s, y:y + s]
     video = F.resize(video, size=(size, size))
     video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))
-    video = video.reshape(video.size(0), -1).unsqueeze(0)
-
-    return video
-
+    return video.reshape(video.size(0), -1).unsqueeze(0)
 
 def init_model(config, checkpoint):
    cfg = nncore.Config.from_file(config)
    cfg.model.init = True
-
-    if checkpoint.startswith('http'):
-        checkpoint = nncore.download(checkpoint, out_dir='checkpoints', verbose=False)
-
    model = build_model(cfg.model, dist=False).eval()
    model = load_checkpoint(model, checkpoint, warning=False)
-
    return model, cfg
 
-
-def main(video, query, model, cfg):
-    if len(query) == 0:
-        raise gr.Error('Text query can not be empty.')
-
+def process_video(video_path: str, query: str, model, cfg) -> dict:
+    if not query:
+        raise ValueError("Text query cannot be empty.")
     try:
-        video = load_video(video, cfg)
-    except Exception:
-        raise gr.Error('Failed to load the video.')
-
+        video = load_video(video_path, cfg)
+    except Exception as e:
+        raise ValueError(f"Failed to load video: {str(e)}")
     query = clip.tokenize(query, truncate=True)
-
     device = next(model.parameters()).device
     data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])
-
     with torch.inference_mode():
         pred = model(data)
-
     mr = pred['_out']['boundary'][:5].cpu().tolist()
     mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]
-
     hd = pred['_out']['saliency'].cpu()
     hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).tolist()
-    hd = pd.DataFrame(dict(x=range(0, len(hd) * 2, 2), y=hd))
-
-    return mr, hd
-
+    hd = [{"x": i * 2, "y": y} for i, y in enumerate(hd)]
+    return {"moment_retrieval": mr, "highlight_detection": hd}
 
-model, cfg = init_model(CONFIG, WEIGHT)
+@app.on_event("startup")
+async def startup_event():
+    global model, cfg
+    model, cfg = init_model(CONFIG, WEIGHT)
+    print("Model loaded successfully.")
 
-fn = partial(main, model=model, cfg=cfg)
-
-with gr.Blocks(title=TITLE) as demo:
-    gr.Markdown(TITLE_MD)
-    gr.Markdown(DESCRIPTION_MD)
-    gr.Markdown(GUIDE_MD)
-
-    with gr.Row():
-        with gr.Column():
-            video = gr.Video(label='Video')
-            query = gr.Textbox(label='Text Query')
-
-            with gr.Row():
-                random_btn = gr.Button(value='🔮 Random')
-                gr.ClearButton([video, query], value='🗑️ Reset')
-                submit_btn = gr.Button(value='🚀 Submit')
-
-        with gr.Column():
-            mr = gr.DataFrame(
-                headers=['Start Time', 'End Time', 'Score'], label='Moment Retrieval')
-            hd = gr.LinePlot(
-                x='x',
-                y='y',
-                x_title='Time (seconds)',
-                y_title='Saliency Score',
-                label='Highlight Detection')
-
-    random_btn.click(lambda: random.sample(EXAMPLES, 1)[0], None, [video, query])
-    submit_btn.click(fn, [video, query], [mr, hd])
-
-demo.launch()
+@app.post("/predict")
+async def predict(video: UploadFile = File(...), query: str = Form(...)):
+    try:
+        if not video.content_type.startswith("video/"):
+            raise HTTPException(status_code=400, detail="Invalid file type. Please upload a video.")
+        if not query.strip():
+            raise HTTPException(status_code=400, detail="Text query cannot be empty.")
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
+            temp_file.write(await video.read())
+            temp_file_path = temp_file.name
+        try:
+            result = process_video(temp_file_path, query, model, cfg)
+            return result
+        finally:
+            os.unlink(temp_file_path)
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")
+
+@app.get("/health")
+async def health():
+    return {"status": "healthy"}
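This commit swaps the Gradio UI for a FastAPI service, so the app is now exercised over HTTP rather than through a browser form. Below is a minimal client sketch for the new endpoints; it assumes the server is launched with `uvicorn app:app --port 8000` and that a clip exists at data/sample.mp4 (the launch command, port, file path, and query text are illustrative assumptions, not part of the commit), and it uses the third-party requests package.

import requests

BASE_URL = 'http://localhost:8000'  # assumed uvicorn host/port, not from the commit

# The startup hook loads the checkpoint before traffic is served; /health
# only confirms the process is up.
print(requests.get(f'{BASE_URL}/health').json())  # {'status': 'healthy'}

# /predict expects a multipart upload: a 'video' file part whose content type
# starts with 'video/', plus a 'query' form field.
with open('data/sample.mp4', 'rb') as f:  # hypothetical sample clip
    resp = requests.post(
        f'{BASE_URL}/predict',
        files={'video': ('sample.mp4', f, 'video/mp4')},
        data={'query': 'A man in a white t shirt is showing a nearby cathedral.'})
resp.raise_for_status()

result = resp.json()
print(result['moment_retrieval'])     # top-5 moments as [start 'MM:SS', end 'MM:SS', score]
print(result['highlight_detection'])  # [{'x': seconds, 'y': saliency}, ...], one point per 2 s

Unlike the old Gradio handler, which returned a pandas DataFrame for the LinePlot component, the JSON response carries the saliency curve as plain {x, y} pairs, so any frontend (for example the http://localhost:3000 origin allowed by the CORS middleware) can render it directly.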