lucas-ventura committed on
Commit 0d862c9 · verified · 1 Parent(s): 7f68df1

Update app.py

Files changed (1):
  1. app.py +314 -4
app.py CHANGED
@@ -1,11 +1,321 @@
 
 
 
 
  import gradio as gr


- def greet(name):
-     return "Hello " + name + "!"


- demo = gr.Interface(fn=greet, inputs="textbox", outputs="textbox")

  if __name__ == "__main__":
-     demo.launch()
 
 
+ import os
+ import tempfile
+ from pathlib import Path
+
  import gradio as gr
+ from llama_cookbook.inference.model_utils import load_model as load_model_llamarecipes
+ from llama_cookbook.inference.model_utils import load_peft_model
+ from transformers import AutoTokenizer
+
+ from src.data.single_video import SingleVideo
+ from src.data.utils_asr import PromptASR
+ from src.models.llama_inference import inference
+ from src.test.vidchapters import get_chapters
+ from src.utils import RankedLogger
+ from tools.download.models import download_model
+
+ log = RankedLogger(__name__, rank_zero_only=True)
+
+ # Set up proxies
+ # from urllib.request import getproxies
+ # proxies = getproxies()
+ # os.environ["HTTP_PROXY"] = os.environ["http_proxy"] = proxies["http"]
+ # os.environ["HTTPS_PROXY"] = os.environ["https_proxy"] = proxies["https"]
+ # os.environ["NO_PROXY"] = os.environ["no_proxy"] = "localhost, 127.0.0.1/8, ::1"
+
+ # Global variables to store loaded models
+ base_model = None
+ tokenizer = None
+ current_peft_model = None
+ inference_model = None
+
+ LLAMA_CKPT_PATH = "meta-llama/Llama-3.1-8B-Instruct"
+
+
+ def load_base_model():
+     """Load the base Llama model and tokenizer once at startup."""
+     global base_model, tokenizer
+
+     if base_model is None:
+         log.info(f"Loading base model: {LLAMA_CKPT_PATH}")
+         base_model = load_model_llamarecipes(
+             model_name=LLAMA_CKPT_PATH,
+             device_map="auto",
+             quantization=None,
+             use_fast_kernels=True,
+         )
+         base_model.eval()
+
+         tokenizer = AutoTokenizer.from_pretrained(LLAMA_CKPT_PATH)
+         tokenizer.pad_token = tokenizer.eos_token
+
+         log.info("Base model loaded successfully")
+
+
+ class FastLlamaInference:
+     def __init__(
+         self,
+         model,
+         add_special_tokens: bool = True,
+         temperature: float = 1.0,
+         max_new_tokens: int = 1024,
+         top_p: float = 1.0,
+         top_k: int = 50,
+         use_cache: bool = True,
+         max_padding_length: int = None,
+         do_sample: bool = False,
+         min_length: int = None,
+         repetition_penalty: float = 1.0,
+         length_penalty: int = 1,
+         max_prompt_tokens: int = 35_000,
+     ):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.add_special_tokens = add_special_tokens
+         self.temperature = temperature
+         self.max_new_tokens = max_new_tokens
+         self.top_p = top_p
+         self.top_k = top_k
+         self.use_cache = use_cache
+         self.max_padding_length = max_padding_length
+         self.do_sample = do_sample
+         self.min_length = min_length
+         self.repetition_penalty = repetition_penalty
+         self.length_penalty = length_penalty
+         self.max_prompt_tokens = max_prompt_tokens
+
+     def __call__(self, prompt: str, **kwargs):
+         # Create a dict of default parameters from instance attributes
+         params = {
+             "model": self.model,
+             "tokenizer": self.tokenizer,
+             "prompt": prompt,
+             "add_special_tokens": self.add_special_tokens,
+             "temperature": self.temperature,
+             "max_new_tokens": self.max_new_tokens,
+             "top_p": self.top_p,
+             "top_k": self.top_k,
+             "use_cache": self.use_cache,
+             "max_padding_length": self.max_padding_length,
+             "do_sample": self.do_sample,
+             "min_length": self.min_length,
+             "repetition_penalty": self.repetition_penalty,
+             "length_penalty": self.length_penalty,
+             "max_prompt_tokens": self.max_prompt_tokens,
+         }
+
+         # Update with any overrides passed in kwargs
+         params.update(kwargs)
+
+         return inference(**params)
+
+
+ def load_peft(model_name: str = "asr-10k"):
+     """Load or switch PEFT model while reusing the base model."""
+     global base_model, current_peft_model, inference_model
+
+     # First make sure the base model is loaded
+     if base_model is None:
+         load_base_model()
+
+     # Only load a new PEFT model if it's different from the current one
+     if current_peft_model != model_name:
+         log.info(f"Loading PEFT model: {model_name}")
+         model_path = download_model(model_name)
+
+         if not Path(model_path).exists():
+             log.warning(f"PEFT model does not exist at {model_path}")
+             return False
+
+         # Apply the PEFT model to the base model
+         peft_model = load_peft_model(base_model, model_path)
+
+         peft_model.eval()
+
+         # Create the inference wrapper
+         inference_model = FastLlamaInference(model=peft_model)
+         current_peft_model = model_name
+
+         log.info(f"PEFT model {model_name} loaded successfully")
+         return True
+
+     # Model already loaded
+     return True
+
+
+ def download_from_url(url, output_path):
+     """Download a video from a URL using yt-dlp and save it to output_path."""
+     try:
+         # Import yt-dlp Python package
+         try:
+             import yt_dlp
+         except ImportError:
+             log.error("yt-dlp Python package is not installed")
+             return (
+                 False,
+                 "yt-dlp Python package is not installed. Please install it with 'pip install yt-dlp'.",
+             )
+
+         # Configure yt-dlp options
+         ydl_opts = {
+             "format": "best",
+             "outtmpl": str(output_path),
+             "noplaylist": True,
+             "quiet": True,
+         }
+
+         # Download the video
+         with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+             ydl.download([url])
+
+         # Check if the download was successful
+         if not os.path.exists(output_path):
+             return (
+                 False,
+                 "Download completed but video file not found. Please check the URL.",
+             )
+
+         return True, None
+     except Exception as e:
+         error_msg = f"Error downloading video: {str(e)}"
+         log.error(error_msg)
+         return False, error_msg
+
+
+ def process_video(
+     video_file, video_url, model_name: str = "asr-10k", do_sample: bool = False
+ ):
+     """Process a video file or URL and generate chapters."""
+     progress = gr.Progress()
+     progress(0, desc="Starting...")
+
+     # Check if we have a valid input
+     if video_file is None and not video_url:
+         return "Please upload a video file or provide a URL."
+
+     # Load the PEFT model
+     progress(0.1, desc=f"Loading LoRA parameters from {model_name}...")
+     if not load_peft(model_name):
+         return "Failed to load model. Please try again."
+
+     # Create a temporary directory to save the uploaded or downloaded video
+     with tempfile.TemporaryDirectory() as temp_dir:
+         temp_video_path = Path(temp_dir) / "temp_video.mp4"
+
+         if video_file is not None:
+             # Using uploaded file
+             progress(0.2, desc="Processing uploaded video...")
+             with open(temp_video_path, "wb") as f:
+                 f.write(video_file)
+         else:
+             # Using URL
+             progress(0.2, desc=f"Downloading video from URL: {video_url}...")
+             success, error_msg = download_from_url(video_url, temp_video_path)
+             if not success:
+                 return f"Failed to download video: {error_msg}"
+
+         # Process the video
+         progress(0.3, desc="Extracting ASR transcript...")
+         single_video = SingleVideo(temp_video_path)
+         progress(0.4, desc="Creating prompt...")
+         prompt = PromptASR(chapters=single_video)
+
+         vid_id = single_video.video_ids[0]
+         progress(0.5, desc="Creating prompt...")
+         prompt = prompt.get_prompt_test(vid_id)
+
+         transcript = single_video.get_asr(vid_id)
+         prompt = prompt + transcript
+
+         progress(0.6, desc="Generating chapters with Chapter-Llama...")
+         _, chapters = get_chapters(
+             inference_model,
+             prompt,
+             max_new_tokens=1024,
+             do_sample=do_sample,
+             vid_id=vid_id,
+         )
+
+         # Format the output
+         progress(0.9, desc="Formatting results...")
+         output = ""
+         for timestamp, text in chapters.items():
+             output += f"{timestamp}: {text}\n"
+
+         progress(1.0, desc="Complete!")
+         return output
+
+
+ # Create the Gradio interface
+ with gr.Blocks(title="Chapter-Llama") as demo:
+     gr.Markdown("# Chapter-Llama")
+     gr.Markdown("## Chaptering in Hour-Long Videos with LLMs")
+     gr.Markdown(
+         "Upload a video file or provide a URL to generate chapters automatically."
+     )
+     gr.Markdown(
+         """
+         This demo is currently using only the audio data (ASR), without frame information.
+         We will add audio+captions functionality in the near future, which will improve
+         chapter generation by incorporating visual content.
+
+         - GitHub: [https://github.com/lucas-ventura/chapter-llama](https://github.com/lucas-ventura/chapter-llama)
+         - Website: [https://imagine.enpc.fr/~lucas.ventura/chapter-llama/](https://imagine.enpc.fr/~lucas.ventura/chapter-llama/)
+         """
+     )
+
+     with gr.Row():
+         with gr.Column():
+             with gr.Tab("Upload File"):
+                 video_input = gr.File(
+                     label="Upload Video or Audio File",
+                     file_types=["video", "audio"],
+                     type="binary",
+                 )
+
+             with gr.Tab("Video URL"):
+                 video_url_input = gr.Textbox(
+                     label="YouTube or Video URL",
+                     placeholder="https://youtube.com/watch?v=...",
+                 )
+
+             model_dropdown = gr.Dropdown(
+                 choices=["asr-10k", "asr-1k"],
+                 value="asr-10k",
+                 label="Select Model",
+             )
+             do_sample = gr.Checkbox(
+                 label="Use random sampling", value=False, interactive=True
+             )
+             submit_btn = gr.Button("Generate Chapters")
+
+         with gr.Column():
+             status_area = gr.Markdown("**Status:** Ready to process video")
+             output_text = gr.Textbox(
+                 label="Generated Chapters", lines=10, interactive=False
+             )

+     def update_status_and_process(video_file, video_url, model_name, do_sample):
+         if video_file is None and not video_url:
+             return (
+                 "**Status:** No video uploaded or URL provided",
+                 "Please upload a video file or provide a URL.",
+             )
+         else:
+             return "**Status:** Processing video...", process_video(
+                 video_file, video_url, model_name, do_sample
+             )

+     # Load the base model at startup
+     load_base_model()

+     submit_btn.click(
+         fn=update_status_and_process,
+         inputs=[video_input, video_url_input, model_dropdown, do_sample],
+         outputs=[status_area, output_text],
+     )


  if __name__ == "__main__":
+     # Launch the Gradio app
+     demo.launch(share=True)
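
For reference, a minimal sketch (not part of this commit) of how the pieces added in app.py could be driven without the web UI. It assumes the Chapter-Llama repository layout (the src/ and tools/ packages), installed dependencies (gradio, transformers, llama-cookbook, yt-dlp), and access to the meta-llama/Llama-3.1-8B-Instruct weights; the file name and prompt string are placeholders.

# chapter_sketch.py: hypothetical usage of app.py, not included in this commit
import app  # importing app.py builds the Gradio Blocks and loads the base model once

# Attach the "asr-10k" LoRA adapter on top of the shared base model.
# load_peft() returns False if the adapter checkpoint cannot be found.
if app.load_peft("asr-10k"):
    # load_peft() stores a FastLlamaInference wrapper in the module-level
    # inference_model global; calling it forwards the prompt to inference().
    result = app.inference_model(
        "<ASR transcript with timestamps>",  # placeholder prompt
        max_new_tokens=256,
    )
    print(result)

Running python app.py directly instead serves the interface and, because launch(share=True) is used, also exposes a temporary public share link.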