MisbahKhan committed on
Commit cd4c391 · verified · 1 Parent(s): b0d1738

Update app.py

Files changed (1)
  1. app.py +119 -81
app.py CHANGED
@@ -1,31 +1,21 @@
-# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
-
 import random
 from functools import partial
-
+import gradio as gr
+import torch
 import clip
 import decord
-import gradio as gr
 import nncore
 import numpy as np
-import torch
+import pandas as pd
 import torchvision.transforms.functional as F
 from decord import VideoReader
 from nncore.engine import load_checkpoint
 from nncore.nn import build_model
 
-import pandas as pd
-
-TITLE = '🌀R2-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding'
-
-TITLE_MD = '<h1 align="center">🌀R<sup>2</sup>-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding</h1>'
-DESCRIPTION_MD = 'R<sup>2</sup>-Tuning is a parameter- and memory-efficient transfer learning method for video temporal grounding. Please find more details in our <a href="https://arxiv.org/abs/2404.00801" target="_blank">Tech Report</a> and <a href="https://github.com/yeliudev/R2-Tuning" target="_blank">GitHub Repo</a>.'
-GUIDE_MD = '### User Guide:\n1. Upload a video or click "random" to sample one.\n2. Input a text query. A good practice is to write a sentence with 5~15 words.\n3. Click "submit" and you\'ll see the moment retrieval and highlight detection results on the right.'
-
+TITLE = '🌀 R2-Tuning: Efficient Image-to-Video Transfer Learning'
 CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
 WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth'
 
-# yapf:disable
 EXAMPLES = [
     ('data/gTAvxnQtjXM_60.0_210.0.mp4', 'A man in a white t shirt wearing a backpack is showing a nearby cathedral.'),
     ('data/pA6Z-qYhSNg_210.0_360.0.mp4', 'Different Facebook posts on transgender bathrooms are shown.'),
@@ -33,22 +23,17 @@ EXAMPLES = [
     ('data/ocLUzCNodj4_360.0_510.0.mp4', 'A woman stands in her bedroom in front of a mirror and talks.'),
     ('data/HkLfNhgP0TM_660.0_810.0.mp4', 'Woman lays down on the couch while talking to the camera.')
 ]
-# yapf:enable
-
 
 def convert_time(seconds):
     minutes, seconds = divmod(round(max(seconds, 0)), 60)
     return f'{minutes:02d}:{seconds:02d}'
 
-
 def load_video(video_path, cfg):
     decord.bridge.set_bridge('torch')
-
     vr = VideoReader(video_path)
     stride = vr.get_avg_fps() / cfg.data.val.fps
     fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
     video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255
-
     size = 336 if '336px' in cfg.model.arch else 224
     h, w = video.size(-2), video.size(-1)
     s = min(h, w)
@@ -56,81 +41,134 @@ def load_video(video_path, cfg):
     video = video[..., x:x + s, y:y + s]
     video = F.resize(video, size=(size, size))
     video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))
-    video = video.reshape(video.size(0), -1).unsqueeze(0)
-
-    return video
-
+    return video.reshape(video.size(0), -1).unsqueeze(0)
 
 def init_model(config, checkpoint):
    cfg = nncore.Config.from_file(config)
    cfg.model.init = True
-
    if checkpoint.startswith('http'):
        checkpoint = nncore.download(checkpoint, out_dir='checkpoints', verbose=False)
-
    model = build_model(cfg.model, dist=False).eval()
-    model = load_checkpoint(model, checkpoint, warning=False)
-
-    return model, cfg
-
+    return load_checkpoint(model, checkpoint, warning=False), cfg
 
 def main(video, query, model, cfg):
-    if len(query) == 0:
-        raise gr.Error('Text query can not be empty.')
-
+    if not video:
+        raise gr.Error("Please upload a video.")
+    if not query:
+        raise gr.Error("Text query cannot be empty.")
     try:
         video = load_video(video, cfg)
-    except Exception:
-        raise gr.Error('Failed to load the video.')
-
-    query = clip.tokenize(query, truncate=True)
-
-    device = next(model.parameters()).device
-    data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])
-
-    with torch.inference_mode():
-        pred = model(data)
-
-    mr = pred['_out']['boundary'][:5].cpu().tolist()
-    mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]
-
-    hd = pred['_out']['saliency'].cpu()
-    hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).tolist()
-    hd = pd.DataFrame(dict(x=range(0, len(hd) * 2, 2), y=hd))
-
-    return mr, hd
-
+        query = clip.tokenize(query, truncate=True)
+        device = next(model.parameters()).device
+        data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])
+        with torch.inference_mode():
+            pred = model(data)
+        mr = pred['_out']['boundary'][:5].cpu().tolist()
+        mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]
+        hd = pred['_out']['saliency'].cpu()
+        hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).tolist()
+        hd = pd.DataFrame(dict(x=range(0, len(hd) * 2, 2), y=hd))
+        gr.Info("Results generated successfully!")
+        return mr, hd
+    except Exception as e:
+        raise gr.Error(f"Error processing request: {str(e)}")
 
 model, cfg = init_model(CONFIG, WEIGHT)
-
 fn = partial(main, model=model, cfg=cfg)
 
-with gr.Blocks(title=TITLE) as demo:
-    gr.Markdown(TITLE_MD)
-    gr.Markdown(DESCRIPTION_MD)
-    gr.Markdown(GUIDE_MD)
-
-    with gr.Row():
-        with gr.Column():
-            video = gr.Video(label='Video')
-            query = gr.Textbox(label='Text Query')
-
-            with gr.Row():
-                random_btn = gr.Button(value='🔮 Random')
-                gr.ClearButton([video, query], value='🗑️ Reset')
-                submit_btn = gr.Button(value='🚀 Submit')
-
-        with gr.Column():
-            mr = gr.DataFrame(
-                headers=['Start Time', 'End Time', 'Score'], label='Moment Retrieval')
-            hd = gr.LinePlot(
-                x='x',
-                y='y',
-                x_title='Time (seconds)',
-                y_title='Saliency Score',
-                label='Highlight Detection')
-
-    random_btn.click(lambda: random.sample(EXAMPLES, 1)[0], None, [video, query])
-    submit_btn.click(fn, [video, query], [mr, hd])
-
-demo.launch()
+# Custom CSS
+custom_css = """
+.block { padding: 2rem; }
+.input-card, .output-card {
+    border: 1px solid #E5E7EB;
+    border-radius: 8px;
+    padding: 1rem;
+    background: #FFFFFF;
+    box-shadow: 0 2px 4px rgba(0,0,0,0.05);
+}
+.markdown-guide {
+    background: #F1F5F9;
+    padding: 1rem;
+    border-radius: 8px;
+}
+.video-input {
+    border-radius: 8px;
+    overflow: hidden;
+    border: 1px solid #E5E7EB;
+}
+.button-primary {
+    transition: all 0.2s ease;
+}
+.button-primary:hover {
+    transform: scale(1.05);
+    box-shadow: 0 4px 8px rgba(0,0,0,0.1);
+}
+@media (max-width: 768px) {
+    .block { padding: 1rem; }
+    .input-card, .output-card { padding: 0.5rem; }
+    h1 { font-size: 1.8rem; }
+}
+"""
+
+# UI
+custom_theme = gr.themes.Base(
+    primary_hue="blue",
+    secondary_hue="gray",
+    neutral_hue="zinc",
+    radius_size="lg",
+    text_size="md",
+    font=["Inter", "Roboto", "sans-serif"],
+)
+
+TITLE_MD = '<h1 align="center" style="font-size: 2.5rem; font-weight: 700;">🌀 R<sup>2</sup>-Tuning: Image-to-Video Transfer Learning</h1>'
+DESCRIPTION_MD = '''
+<div style="text-align: center; font-size: 1.1rem; color: #4B5EAA;">
+R<sup>2</sup>-Tuning is a parameter-efficient method for video temporal grounding.
+Explore our <a href="https://arxiv.org/abs/2404.00801" style="color: #1D4ED8;">Tech Report</a>
+and <a href="https://github.com/yeliudev/R2-Tuning" style="color: #1D4ED8;">GitHub Repo</a>.
+</div>
+'''
+GUIDE_MD = '''
+### 📋 User Guide
+1. **Upload a video** or click "Random" to try a sample.
+2. **Enter a text query** (5–15 words recommended).
+3. **Click Submit** to view moment retrieval and highlight detection results.
+'''
+
+with gr.Blocks(title=TITLE, theme=custom_theme, css=custom_css) as demo:
+    gr.Markdown(TITLE_MD, elem_classes="text-center")
+    gr.Markdown(DESCRIPTION_MD, elem_classes="text-center")
+    gr.Markdown(GUIDE_MD, elem_classes="markdown-guide")
+
+    with gr.Row(variant="panel"):
+        with gr.Column(scale=1, min_width=400):
+            with gr.Group(elem_classes="input-card"):
+                video = gr.Video(label='Upload Video', elem_classes="video-input", height=300)
+                query = gr.Textbox(label='Text Query', placeholder="Enter a descriptive sentence (5-15 words)...")
+                with gr.Row():
+                    random_btn = gr.Button(value='🔮 Random', variant="secondary")
+                    gr.ClearButton([video, query], value='🗑️ Reset', variant="secondary")
+                    submit_btn = gr.Button(value='🚀 Submit', variant="primary")
+        with gr.Column(scale=1, min_width=400):
+            with gr.Group(elem_classes="output-card"):
+                mr = gr.DataFrame(
+                    headers=['Start Time', 'End Time', 'Score'],
+                    label='Moment Retrieval',
+                    elem_classes="result-table"
+                )
+                hd = gr.LinePlot(
+                    x='x',
+                    y='y',
+                    x_title='Time (seconds)',
+                    y_title='Saliency Score',
+                    label='Highlight Detection',
+                    color="#4B5EAA",
+                    show_label=True,
+                    height=250,
+                    tooltip=True,
+                    grid=True
+                )
+    random_btn.click(lambda: random.sample(EXAMPLES, 1)[0], None, [video, query])
+    submit_btn.click(fn, [video, query], [mr, hd])
+
+demo.launch()
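
For anyone who wants to sanity-check the refactored inference path without launching the Space, the sketch below repeats the same calls the new main() wires to the Submit button, using the first EXAMPLES clip. It is not part of this commit: the paths and statistics come from app.py itself, but the center-crop offsets (a line that falls between the diff hunks) are re-derived here, so treat it as an untested convenience script.

# smoke_test.py -- hypothetical helper, not included in this commit.
# Mirrors the Submit-button pipeline from app.py; run from the repo root.
import clip
import decord
import nncore
import numpy as np
import torch
import torchvision.transforms.functional as F
from decord import VideoReader
from nncore.engine import load_checkpoint
from nncore.nn import build_model

CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth'
VIDEO = 'data/gTAvxnQtjXM_60.0_210.0.mp4'
QUERY = 'A man in a white t shirt wearing a backpack is showing a nearby cathedral.'

# Build and load the model exactly as init_model() does.
cfg = nncore.Config.from_file(CONFIG)
cfg.model.init = True
ckpt = nncore.download(WEIGHT, out_dir='checkpoints', verbose=False)
model = load_checkpoint(build_model(cfg.model, dist=False).eval(), ckpt, warning=False)

# Frame loading mirrors load_video(): sample at cfg.data.val.fps, center-crop
# to a square, resize, and normalize with the CLIP statistics.
decord.bridge.set_bridge('torch')
vr = VideoReader(VIDEO)
stride = vr.get_avg_fps() / cfg.data.val.fps
idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
video = vr.get_batch(idx).permute(0, 3, 1, 2).float() / 255
size = 336 if '336px' in cfg.model.arch else 224
h, w = video.size(-2), video.size(-1)
s = min(h, w)
x, y = round((h - s) / 2), round((w - s) / 2)  # crop offsets; this line sits between the hunks and is inferred
video = video[..., x:x + s, y:y + s]
video = F.resize(video, size=(size, size))
video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))
video = video.reshape(video.size(0), -1).unsqueeze(0)

# Same forward pass main() performs behind the Submit button.
query = clip.tokenize(QUERY, truncate=True)
device = next(model.parameters()).device
data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])
with torch.inference_mode():
    pred = model(data)
print(pred['_out']['boundary'][:5].cpu().tolist())  # top-5 [start, end, score] moments in seconds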