MisbahKhan committed (verified)
Commit 30b99b4 · 1 Parent(s): cd4c391

Update app.py

Files changed (1): app.py +75 -119
app.py CHANGED
Old version (removed lines marked "-"):

@@ -1,21 +1,25 @@
 import random
 from functools import partial
-import gradio as gr
-import torch
 import clip
 import decord
 import nncore
 import numpy as np
-import pandas as pd
 import torchvision.transforms.functional as F
 from decord import VideoReader
 from nncore.engine import load_checkpoint
 from nncore.nn import build_model

-TITLE = '🌀 R2-Tuning: Efficient Image-to-Video Transfer Learning'
 CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
 WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth'

 EXAMPLES = [
     ('data/gTAvxnQtjXM_60.0_210.0.mp4', 'A man in a white t shirt wearing a backpack is showing a nearby cathedral.'),
     ('data/pA6Z-qYhSNg_210.0_360.0.mp4', 'Different Facebook posts on transgender bathrooms are shown.'),
@@ -23,17 +27,22 @@ EXAMPLES = [
     ('data/ocLUzCNodj4_360.0_510.0.mp4', 'A woman stands in her bedroom in front of a mirror and talks.'),
     ('data/HkLfNhgP0TM_660.0_810.0.mp4', 'Woman lays down on the couch while talking to the camera.')
 ]

 def convert_time(seconds):
     minutes, seconds = divmod(round(max(seconds, 0)), 60)
     return f'{minutes:02d}:{seconds:02d}'

 def load_video(video_path, cfg):
     decord.bridge.set_bridge('torch')
     vr = VideoReader(video_path)
     stride = vr.get_avg_fps() / cfg.data.val.fps
     fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
     video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255
     size = 336 if '336px' in cfg.model.arch else 224
     h, w = video.size(-2), video.size(-1)
     s = min(h, w)
@@ -41,134 +50,81 @@ def load_video(video_path, cfg):
     video = video[..., x:x + s, y:y + s]
     video = F.resize(video, size=(size, size))
     video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))
-    return video.reshape(video.size(0), -1).unsqueeze(0)

 def init_model(config, checkpoint):
     cfg = nncore.Config.from_file(config)
     cfg.model.init = True
     if checkpoint.startswith('http'):
         checkpoint = nncore.download(checkpoint, out_dir='checkpoints', verbose=False)
     model = build_model(cfg.model, dist=False).eval()
-    return load_checkpoint(model, checkpoint, warning=False), cfg

 def main(video, query, model, cfg):
-    if not video:
-        raise gr.Error("Please upload a video.")
-    if not query:
-        raise gr.Error("Text query cannot be empty.")
     try:
         video = load_video(video, cfg)
-        query = clip.tokenize(query, truncate=True)
-        device = next(model.parameters()).device
-        data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])
-        with torch.inference_mode():
-            pred = model(data)
-        mr = pred['_out']['boundary'][:5].cpu().tolist()
-        mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]
-        hd = pred['_out']['saliency'].cpu()
-        hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).tolist()
-        hd = pd.DataFrame(dict(x=range(0, len(hd) * 2, 2), y=hd))
-        gr.Info("Results generated successfully!")
-        return mr, hd
-    except Exception as e:
-        raise gr.Error(f"Error processing request: {str(e)}")

 model, cfg = init_model(CONFIG, WEIGHT)
 fn = partial(main, model=model, cfg=cfg)

-# Custom CSS
-custom_css = """
-.block { padding: 2rem; }
-.input-card, .output-card {
-    border: 1px solid #E5E7EB;
-    border-radius: 8px;
-    padding: 1rem;
-    background: #FFFFFF;
-    box-shadow: 0 2px 4px rgba(0,0,0,0.05);
-}
-.markdown-guide {
-    background: #F1F5F9;
-    padding: 1rem;
-    border-radius: 8px;
-}
-.video-input {
-    border-radius: 8px;
-    overflow: hidden;
-    border: 1px solid #E5E7EB;
-}
-.button-primary {
-    transition: all 0.2s ease;
-}
-.button-primary:hover {
-    transform: scale(1.05);
-    box-shadow: 0 4px 8px rgba(0,0,0,0.1);
-}
-@media (max-width: 768px) {
-    .block { padding: 1rem; }
-    .input-card, .output-card { padding: 0.5rem; }
-    h1 { font-size: 1.8rem; }
-}
-"""
-
-# UI
-custom_theme = gr.themes.Base(
-    primary_hue="blue",
-    secondary_hue="gray",
-    neutral_hue="zinc",
-    radius_size="lg",
-    text_size="md",
-    font=["Inter", "Roboto", "sans-serif"],
-)
-
-TITLE_MD = '<h1 align="center" style="font-size: 2.5rem; font-weight: 700;">🌀 R<sup>2</sup>-Tuning: Image-to-Video Transfer Learning</h1>'
-DESCRIPTION_MD = '''
-<div style="text-align: center; font-size: 1.1rem; color: #4B5EAA;">
-R<sup>2</sup>-Tuning is a parameter-efficient method for video temporal grounding.
-Explore our <a href="https://arxiv.org/abs/2404.00801" style="color: #1D4ED8;">Tech Report</a>
-and <a href="https://github.com/yeliudev/R2-Tuning" style="color: #1D4ED8;">GitHub Repo</a>.
-</div>
-'''
-GUIDE_MD = '''
-### 📋 User Guide
-1. **Upload a video** or click "Random" to try a sample.
-2. **Enter a text query** (5–15 words recommended).
-3. **Click Submit** to view moment retrieval and highlight detection results.
-'''
-
-with gr.Blocks(title=TITLE, theme=custom_theme, css=custom_css) as demo:
-    gr.Markdown(TITLE_MD, elem_classes="text-center")
-    gr.Markdown(DESCRIPTION_MD, elem_classes="text-center")
-    gr.Markdown(GUIDE_MD, elem_classes="markdown-guide")
-
-    with gr.Row(variant="panel"):
-        with gr.Column(scale=1, min_width=400):
-            with gr.Group(elem_classes="input-card"):
-                video = gr.Video(label='Upload Video', elem_classes="video-input", height=300)
-                query = gr.Textbox(label='Text Query', placeholder="Enter a descriptive sentence (5-15 words)...")
                 with gr.Row():
-                    random_btn = gr.Button(value='🔮 Random', variant="secondary")
-                    gr.ClearButton([video, query], value='🗑️ Reset', variant="secondary")
-                    submit_btn = gr.Button(value='🚀 Submit', variant="primary")
-        with gr.Column(scale=1, min_width=400):
-            with gr.Group(elem_classes="output-card"):
-                mr = gr.DataFrame(
-                    headers=['Start Time', 'End Time', 'Score'],
-                    label='Moment Retrieval',
-                    elem_classes="result-table"
-                )
-                hd = gr.LinePlot(
-                    x='x',
-                    y='y',
-                    x_title='Time (seconds)',
-                    y_title='Saliency Score',
-                    label='Highlight Detection',
-                    color="#4B5EAA",
-                    show_label=True,
-                    height=250,
-                    tooltip=True,
-                    grid=True
-                )
-    random_btn.click(lambda: random.sample(EXAMPLES, 1)[0], None, [video, query])
-    submit_btn.click(fn, [video, query], [mr, hd])
-
-demo.launch()

New version (added lines marked "+"):

+
 import random
 from functools import partial
+
 import clip
 import decord
+import gradio as gr
 import nncore
 import numpy as np
+import torch
 import torchvision.transforms.functional as F
 from decord import VideoReader
 from nncore.engine import load_checkpoint
 from nncore.nn import build_model

+import pandas as pd
+
+
 CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
 WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth'

+# yapf:disable
 EXAMPLES = [
     ('data/gTAvxnQtjXM_60.0_210.0.mp4', 'A man in a white t shirt wearing a backpack is showing a nearby cathedral.'),
     ('data/pA6Z-qYhSNg_210.0_360.0.mp4', 'Different Facebook posts on transgender bathrooms are shown.'),
     ...
     ('data/ocLUzCNodj4_360.0_510.0.mp4', 'A woman stands in her bedroom in front of a mirror and talks.'),
     ('data/HkLfNhgP0TM_660.0_810.0.mp4', 'Woman lays down on the couch while talking to the camera.')
 ]
+# yapf:enable
+

 def convert_time(seconds):
     minutes, seconds = divmod(round(max(seconds, 0)), 60)
     return f'{minutes:02d}:{seconds:02d}'

+
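
convert_time clamps negative inputs to zero, rounds to the nearest whole second, and formats the result as MM:SS. With the function above in scope, a few illustrative checks:

```python
assert convert_time(-3.0) == '00:00'   # negatives clamp to 00:00
assert convert_time(125.4) == '02:05'  # 125 s -> 2 min 5 s
assert convert_time(59.6) == '01:00'   # 59.6 s rounds to 60 s
```
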
 def load_video(video_path, cfg):
     decord.bridge.set_bridge('torch')
+
     vr = VideoReader(video_path)
     stride = vr.get_avg_fps() / cfg.data.val.fps
     fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
     video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255
+
     size = 336 if '336px' in cfg.model.arch else 224
     h, w = video.size(-2), video.size(-1)
     s = min(h, w)
     ...
     video = video[..., x:x + s, y:y + s]
     video = F.resize(video, size=(size, size))
     video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))
+    video = video.reshape(video.size(0), -1).unsqueeze(0)
+
+    return video
+
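
load_video samples frames at the rate the config expects (cfg.data.val.fps), center-crops to a square, resizes to CLIP's input resolution, and normalizes with CLIP's channel statistics before flattening. The stride arithmetic can be checked in isolation; a minimal sketch, assuming a hypothetical 150-frame clip at 30 fps and a target rate of 0.5 fps (one frame every 2 seconds, consistent with the 2-second steps of the saliency plot below):

```python
import numpy as np

num_frames, native_fps, target_fps = 150, 30.0, 0.5  # assumed values
stride = native_fps / target_fps  # keep one of every 60 source frames
fm_idx = [min(round(i), num_frames - 1) for i in np.arange(0, num_frames, stride).tolist()]
print(fm_idx)  # [0, 60, 120] -> frames at t = 0 s, 2 s, 4 s
```
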

 def init_model(config, checkpoint):
     cfg = nncore.Config.from_file(config)
     cfg.model.init = True
+
     if checkpoint.startswith('http'):
         checkpoint = nncore.download(checkpoint, out_dir='checkpoints', verbose=False)
+
     model = build_model(cfg.model, dist=False).eval()
+    model = load_checkpoint(model, checkpoint, warning=False)
+
+    return model, cfg
+
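
init_model only downloads when the checkpoint string starts with 'http'; anything else is treated as a path on disk, so both call patterns below should work (the local path is hypothetical, e.g. a file saved by a previous run of nncore.download):

```python
model, cfg = init_model(CONFIG, WEIGHT)  # remote URL, fetched into checkpoints/
# model, cfg = init_model(CONFIG, 'checkpoints/r2_tuning_qvhighlights-ed516355.pth')  # local file
```
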

 def main(video, query, model, cfg):
+    if len(query) == 0:
+        raise gr.Error('Text query can not be empty.')
+
     try:
         video = load_video(video, cfg)
+    except Exception:
+        raise gr.Error('Failed to load the video.')
+
+    query = clip.tokenize(query, truncate=True)
+
+    device = next(model.parameters()).device
+    data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])
+
+    with torch.inference_mode():
+        pred = model(data)
+
+    mr = pred['_out']['boundary'][:5].cpu().tolist()
+    mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]
+
+    hd = pred['_out']['saliency'].cpu()
+    hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).tolist()
+    hd = pd.DataFrame(dict(x=range(0, len(hd) * 2, 2), y=hd))
+
+    return mr, hd
+
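
main() returns two views of one forward pass: the top-5 predicted moments as (start, end, score) rows, and a per-clip saliency curve that is min-max normalized and then squeezed into [0.05, 0.95] so the plotted line stays clear of the axes. The rescaling on a toy tensor:

```python
import torch

hd = torch.tensor([-1.0, 0.0, 3.0])                        # raw saliency scores
hd = (hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05  # into [0.05, 0.95]
print(hd.tolist())  # ~[0.05, 0.275, 0.95]
```
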

 model, cfg = init_model(CONFIG, WEIGHT)
+
 fn = partial(main, model=model, cfg=cfg)

+with gr.Blocks(title=TITLE) as demo:
+    gr.Markdown(TITLE_MD)
+    gr.Markdown(DESCRIPTION_MD)
+    gr.Markdown(GUIDE_MD)
+
+    with gr.Row():
+        with gr.Column():
+            video = gr.Video(label='Video')
+            query = gr.Textbox(label='Text Query')
+
             with gr.Row():
+                random_btn = gr.Button(value='🔮 Random')
+                gr.ClearButton([video, query], value='🗑️ Reset')
+                submit_btn = gr.Button(value='🚀 Submit')
+
+        with gr.Column():
+            mr = gr.DataFrame(
+                headers=['Start Time', 'End Time', 'Score'], label='Moment Retrieval')
+            hd = gr.LinePlot(
+                x='x',
+                y='y',
+                x_title='Time (seconds)',
+                y_title='Saliency Score',
+                label='Highlight Detection')
+
+    random_btn.click(lambda: random.sample(EXAMPLES, 1)[0], None, [video, query])
+    submit_btn.click(fn, [video, query], [mr, hd])
+
+demo.launch()
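
With the dependencies installed and the sample clips available under data/, running `python app.py` should start the demo: demo.launch() with no arguments serves the Blocks app on Gradio's default local URL. Note that gr.LinePlot reads the x/y columns of the pandas DataFrame built in main(), which is what the pandas import is for.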