import os
import hashlib

import requests
import numpy as np
from decord import VideoReader, cpu
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import gradio as gr
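# Assumed dependencies (the original does not pin versions):
#   pip install torch transformers qwen-vl-utils decord gradio requests numpy
# Note: Qwen2_5_VLForConditionalGeneration requires a transformers release
# that includes Qwen2.5-VL support.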
# ---------------------------------------------------
# 1. Set Up Device: Use Apple's MPS if available, else CPU
# ---------------------------------------------------
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

# For MPS, we use float16 to reduce memory usage; CPU inference stays in float32.
torch_dtype = torch.float16 if device == "mps" else torch.float32
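# Note (assumption): some ops used by Qwen2.5-VL may not yet be implemented
# on MPS. If you hit a NotImplementedError, launching the script with
#   PYTORCH_ENABLE_MPS_FALLBACK=1 python app.py
# tells PyTorch to fall back to the CPU for the unsupported ops.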
# ---------------------------------------------------
# 2. Initialize the Qwen 2.5 VL Model (3B) for Local Use
# ---------------------------------------------------
model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch_dtype,
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_path)
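# This script only runs inference, so put the model in eval mode
# (disables dropout; generate() itself already runs without gradients).
model.eval()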
# ---------------------------------------------------
# 3. Utility Functions for Video Processing
# ---------------------------------------------------
def download_video(url, dest_path):
    """
    Download a video from a URL to dest_path.
    (Kept here in case you ever need to download via URL.)
    """
    response = requests.get(url, stream=True)
    response.raise_for_status()  # fail early on HTTP errors instead of caching a bad file
    with open(dest_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
    print(f"Video downloaded to {dest_path}")
def get_video_frames(video_path, num_frames=64, cache_dir='.cache'):
    """
    Extract frames and timestamps from a video file.
    If video_path is a URL, it is downloaded first; otherwise a local file is assumed.
    Results are cached on disk to avoid re-processing the same video.
    """
    os.makedirs(cache_dir, exist_ok=True)
    video_hash = hashlib.md5(video_path.encode('utf-8')).hexdigest()

    # If the path starts with 'http', download the file.
    if video_path.startswith("http"):
        video_file_path = os.path.join(cache_dir, f"{video_hash}.mp4")
        if not os.path.exists(video_file_path):
            print("Downloading video using requests...")
            download_video(video_path, video_file_path)
    else:
        video_file_path = video_path

    frames_cache_file = os.path.join(cache_dir, f"{video_hash}_{num_frames}_frames.npy")
    timestamps_cache_file = os.path.join(cache_dir, f"{video_hash}_{num_frames}_timestamps.npy")
    if os.path.exists(frames_cache_file) and os.path.exists(timestamps_cache_file):
        frames = np.load(frames_cache_file)
        timestamps = np.load(timestamps_cache_file)
        return video_file_path, frames, timestamps

    # Load the video with decord and sample num_frames evenly spaced frames.
    vr = VideoReader(video_file_path, ctx=cpu(0))
    total_frames = len(vr)
    indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
    frames = vr.get_batch(indices).asnumpy()
    timestamps = np.array([vr.get_frame_timestamp(idx) for idx in indices])

    # Cache the frames and timestamps for next time.
    np.save(frames_cache_file, frames)
    np.save(timestamps_cache_file, timestamps)
    return video_file_path, frames, timestamps
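# Example usage (hypothetical local file): 64 evenly spaced frames, cached
# under .cache/ keyed by an MD5 of the path. decord returns frames as a
# (num_frames, H, W, 3) uint8 array and each timestamp as a (start, end) pair:
#   path, frames, ts = get_video_frames("clip.mp4", num_frames=64)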
# ---------------------------------------------------
# 4. Inference Function Using Qwen 2.5 VL (3B)
# ---------------------------------------------------
def inference(video_path, prompt, max_new_tokens=2048, total_pixels=20480 * 28 * 28, min_pixels=16 * 28 * 28):
    """
    Prepares the input with the prompt and video metadata,
    processes the video inputs, and runs inference through the model.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {"type": "text", "text": prompt},
            {"video": video_path, "total_pixels": total_pixels, "min_pixels": min_pixels},
        ]},
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs, video_kwargs = process_vision_info([messages], return_video_kwargs=True)
    fps_inputs = video_kwargs["fps"]
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        fps=fps_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Move inputs to our chosen device (MPS or CPU).
    inputs = inputs.to(device)

    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Strip the prompt tokens so only the newly generated text is decoded.
    generated_ids = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return output_text[0]
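# Example call (hypothetical local clip), bypassing the Gradio UI:
#   print(inference("clip.mp4", "Split this video into chapters in the format 'mm:ss Title'."))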
# ---------------------------------------------------
# 5. Define Sample Prompts
# ---------------------------------------------------
sample_prompts = [
    "Please analyze the video and split it into chapters with timestamps and descriptive titles in the format 'mm:ss Title'.",
    "Provide a breakdown of the video's content by segment, including starting times and summaries.",
    "Segment the video into logical chapters and output the start time and a brief description for each chapter.",
]
# ---------------------------------------------------
# 6. Main Processing Function for the Gradio Interface
# ---------------------------------------------------
def process_video(video_file, custom_prompt, sample_prompt):
    """
    Called when the user clicks 'Process Video'.
    Uses the custom prompt if one was entered, otherwise the selected sample prompt.
    Processes the uploaded video and runs inference.
    """
    final_prompt = custom_prompt.strip() if custom_prompt.strip() != "" else sample_prompt
    try:
        # video_file is the local file path provided by the Gradio uploader.
        # Note: the extracted frames/timestamps are cached here but not passed to
        # the model; process_vision_info() reads the video again inside inference().
        video_path, frames, timestamps = get_video_frames(video_file, num_frames=64)
    except Exception as e:
        return f"Error processing video: {str(e)}"
    try:
        output = inference(video_path, final_prompt)
    except Exception as e:
        return f"Error during inference: {str(e)}"
    return output
# ---------------------------------------------------
# 7. Build the Gradio Interface for Local Use
# ---------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Video Chapter Splitter using Qwen 2.5 VL (3B) on Mac")
    gr.Markdown("Upload a video file and either type a custom prompt or select one of the sample prompts. Then click **Process Video** to generate the chapter breakdown.")
    with gr.Row():
        video_input = gr.Video(label="Upload Video")
    with gr.Row():
        custom_prompt_input = gr.Textbox(label="Custom Prompt", placeholder="Enter custom prompt (optional)...", lines=2)
    with gr.Row():
        sample_prompt_input = gr.Dropdown(label="Sample Prompts", choices=sample_prompts, value=sample_prompts[0])
    output_text = gr.Textbox(label="Output", lines=10)
    run_button = gr.Button("Process Video")
    run_button.click(fn=process_video, inputs=[video_input, custom_prompt_input, sample_prompt_input], outputs=output_text)
if __name__ == "__main__":
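    # Gradio serves the app locally (http://127.0.0.1:7860 by default);
    # pass share=True to launch() if you need a temporary public link.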
    demo.launch()