import os
import hashlib
import requests
import numpy as np
from PIL import Image
import decord
from decord import VideoReader, cpu
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import gradio as gr

# ---------------------------------------------------
# 1. Set Up Device: Use Apple's MPS if available, else CPU
# ---------------------------------------------------
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

# For MPS, we can try using float16 to reduce memory usage.
torch_dtype = torch.float16 if device == "mps" else torch.float32

# ---------------------------------------------------
# 2. Initialize the Qwen 2.5 VL Model (3B) for Local Use
# ---------------------------------------------------
model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path, torch_dtype=torch_dtype
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_path)

# ---------------------------------------------------
# 3. Utility Functions for Video Processing
# ---------------------------------------------------
def download_video(url, dest_path):
    """
    Downloads a video from a URL.
    (This function is kept here if you ever need to download via URL.)
    """
    response = requests.get(url, stream=True)
    with open(dest_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8096):
            f.write(chunk)
    print(f"Video downloaded to {dest_path}")


def get_video_frames(video_path, num_frames=64, cache_dir='.cache'):
    """
    Extract frames and timestamps from a video file.
    If video_path is a URL, it downloads it; otherwise it assumes a local file.
    Caching is used to avoid re-processing.
    """
    os.makedirs(cache_dir, exist_ok=True)
    video_hash = hashlib.md5(video_path.encode('utf-8')).hexdigest()

    # If the path starts with 'http', download the file.
    if video_path.startswith("http"):
        video_file_path = os.path.join(cache_dir, f"{video_hash}.mp4")
        if not os.path.exists(video_file_path):
            print("Downloading video using requests...")
            download_video(video_path, video_file_path)
    else:
        video_file_path = video_path

    frames_cache_file = os.path.join(cache_dir, f"{video_hash}_{num_frames}_frames.npy")
    timestamps_cache_file = os.path.join(cache_dir, f"{video_hash}_{num_frames}_timestamps.npy")

    if os.path.exists(frames_cache_file) and os.path.exists(timestamps_cache_file):
        frames = np.load(frames_cache_file)
        timestamps = np.load(timestamps_cache_file)
        return video_file_path, frames, timestamps

    # Load video using decord
    vr = VideoReader(video_file_path, ctx=cpu(0))
    total_frames = len(vr)
    indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
    frames = vr.get_batch(indices).asnumpy()
    timestamps = np.array([vr.get_frame_timestamp(idx) for idx in indices])

    # Cache the frames and timestamps
    np.save(frames_cache_file, frames)
    np.save(timestamps_cache_file, timestamps)

    return video_file_path, frames, timestamps


# ---------------------------------------------------
# 4. Inference Function Using Qwen 2.5 VL (3B)
# ---------------------------------------------------
def inference(video_path, prompt, max_new_tokens=2048, total_pixels=20480 * 28 * 28, min_pixels=16 * 28 * 28):
    """
    Prepares the input with the prompt and video metadata, processes the video inputs,
    and runs inference through the model.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {"type": "text", "text": prompt},
            {"video": video_path, "total_pixels": total_pixels, "min_pixels": min_pixels},
        ]},
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs, video_kwargs = process_vision_info([messages], return_video_kwargs=True)
    fps_inputs = video_kwargs["fps"]

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        fps=fps_inputs,
        padding=True,
        return_tensors="pt"
    )
    # Move inputs to our chosen device (MPS or CPU)
    inputs = inputs.to(device)

    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return output_text[0]


# ---------------------------------------------------
# 5. Define Sample Prompts
# ---------------------------------------------------
sample_prompts = [
    "Please analyze the video and split it into chapters with timestamps and descriptive titles in the format 'mm:ss Title'.",
    "Provide a breakdown of the video's content by segment, including starting times and summaries.",
    "Segment the video into logical chapters and output the start time and a brief description for each chapter.",
]

# ---------------------------------------------------
# 6. Main Processing Function for the Gradio Interface
# ---------------------------------------------------
def process_video(video_file, custom_prompt, sample_prompt):
    """
    Called when the user clicks 'Process Video'.
    Uses a custom prompt (if provided) or the sample prompt.
    Processes the uploaded video and runs inference.
    """
    final_prompt = custom_prompt.strip() if custom_prompt.strip() != "" else sample_prompt
    try:
        # Here, video_file is the local file path from the uploader.
        video_path, frames, timestamps = get_video_frames(video_file, num_frames=64)
    except Exception as e:
        return f"Error processing video: {str(e)}"

    try:
        output = inference(video_path, final_prompt)
    except Exception as e:
        return f"Error during inference: {str(e)}"

    return output


# ---------------------------------------------------
# 7. Build the Gradio Interface for Local Use
# ---------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Video Chapter Splitter using Qwen 2.5 VL (3B) on Mac")
    gr.Markdown("Upload a video file and either type a custom prompt or select one of the sample prompts. Then click **Process Video** to generate the chapter breakdown.")

    with gr.Row():
        video_input = gr.Video(label="Upload Video")
    with gr.Row():
        custom_prompt_input = gr.Textbox(label="Custom Prompt", placeholder="Enter custom prompt (optional)...", lines=2)
    with gr.Row():
        sample_prompt_input = gr.Dropdown(label="Sample Prompts", choices=sample_prompts, value=sample_prompts[0])

    output_text = gr.Textbox(label="Output", lines=10)
    run_button = gr.Button("Process Video")

    run_button.click(fn=process_video, inputs=[video_input, custom_prompt_input, sample_prompt_input], outputs=output_text)

if __name__ == "__main__":
    demo.launch()