import os
import hashlib
import requests
import numpy as np
from decord import VideoReader, cpu
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import gradio as gr
# Removed pytube since we no longer download from YouTube
# ----------------------------------------
# 1. Initialize the Qwen 2.5 VL Model (3B) for CPU-only
# ----------------------------------------
model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.float32,  # float32 is the safe choice on CPU; many float16 ops are unsupported or slow there
    # attn_implementation and device_map are omitted for CPU-only deployment
)
processor = AutoProcessor.from_pretrained(model_path)
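# A hedged alternative (assumes a CUDA GPU, which this CPU-only Space does not have):
# the model could instead be loaded in half precision and placed automatically, e.g.
#   model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#       model_path, torch_dtype=torch.float16, device_map="auto"
#   )
# This script intentionally keeps everything on the CPU, so no device placement is done.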
# -------------------------------------------------
# 2. Define Utility Functions for Video Processing
# -------------------------------------------------
def download_video(url, dest_path):
    """
    Download a non-YouTube video over HTTP using requests.
    (This function is retained in case URL inputs are needed later.)
    """
    response = requests.get(url, stream=True)
    response.raise_for_status()  # fail early on HTTP errors instead of writing an error page to disk
    with open(dest_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):  # stream in 8 KiB chunks
            f.write(chunk)
    print(f"Video downloaded to {dest_path}")
def get_video_frames(video_path, num_frames=16, cache_dir='.cache'):
    """
    Extract evenly spaced frames and their timestamps from a video file.
    If video_path is a URL, the video is downloaded first.
    Local files (including uploaded videos) are processed directly.
    Frames and timestamps are cached on disk to avoid repeated decoding.
    """
    os.makedirs(cache_dir, exist_ok=True)
    video_hash = hashlib.md5(video_path.encode('utf-8')).hexdigest()
    # If video_path looks like a URL, download it into the cache directory
    if video_path.startswith('http'):
        video_file_path = os.path.join(cache_dir, f'{video_hash}.mp4')
        if not os.path.exists(video_file_path):
            print("Downloading video using requests...")
            download_video(video_path, video_file_path)
    else:
        # For local files (uploaded videos), use the provided path directly.
        video_file_path = video_path

    # Return cached frames and timestamps if they exist
    frames_cache_file = os.path.join(cache_dir, f'{video_hash}_{num_frames}_frames.npy')
    timestamps_cache_file = os.path.join(cache_dir, f'{video_hash}_{num_frames}_timestamps.npy')
    if os.path.exists(frames_cache_file) and os.path.exists(timestamps_cache_file):
        frames = np.load(frames_cache_file)
        timestamps = np.load(timestamps_cache_file)
        return video_file_path, frames, timestamps

    # Decode the video with decord and sample num_frames evenly spaced frames
    vr = VideoReader(video_file_path, ctx=cpu(0))
    total_frames = len(vr)
    indices = np.linspace(0, total_frames - 1, num=num_frames, dtype=int)
    frames = vr.get_batch(indices).asnumpy()
    # get_frame_timestamp returns a (start, end) pair in seconds for each frame
    timestamps = np.array([vr.get_frame_timestamp(idx) for idx in indices])

    # Save to cache
    np.save(frames_cache_file, frames)
    np.save(timestamps_cache_file, timestamps)
    return video_file_path, frames, timestamps
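# Example usage (hypothetical local file): samples 16 evenly spaced frames.
#   path, frames, timestamps = get_video_frames("sample.mp4", num_frames=16)
#   frames has shape (16, H, W, 3); timestamps holds one (start, end) pair per frame.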
# --------------------------------------------------------
# 3. Inference Function Using Qwen 2.5 VL to Process the Video
# --------------------------------------------------------
def inference(video_path, prompt, max_new_tokens=2048, total_pixels=20480 * 28 * 28, min_pixels=16 * 28 * 28):
    """
    Prepares the chat messages with the prompt and video metadata,
    processes the video inputs, and runs generation through the model.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {"type": "text", "text": prompt},
            {"video": video_path, "total_pixels": total_pixels, "min_pixels": min_pixels},
        ]},
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs, video_kwargs = process_vision_info([messages], return_video_kwargs=True)
    fps_inputs = video_kwargs['fps']
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        fps=fps_inputs,
        padding=True,
        return_tensors="pt"
    )
    # In CPU-only mode, inputs stay on the default device (no .to('cuda'))
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Strip the prompt tokens so only the newly generated tokens are decoded
    generated_ids = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return output_text[0]
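# Example usage (hypothetical local file; generation on CPU is slow):
#   print(inference("sample.mp4", "Split this video into chapters with timestamps."))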
# -------------------------------------------------
# 4. Define Sample Prompts for Users
# -------------------------------------------------
sample_prompts = [
    "Please analyze the video and split it into chapters with timestamps and descriptive titles in the format 'mm:ss Title'.",
    "Provide a breakdown of the video's content by segment, including starting times and summaries.",
    "Segment the video into logical chapters and output the start time and a brief description for each chapter.",
]
# -------------------------------------------------
# 5. Main Processing Function for the Gradio Interface
# -------------------------------------------------
def process_video(video_file, custom_prompt, sample_prompt):
    """
    Called when the user clicks 'Process Video'.
    Uses the custom prompt if provided; otherwise falls back to the selected sample prompt.
    Extracts frames from the uploaded video file and runs inference.
    """
    if video_file is None:
        return "Please upload a video file first."
    final_prompt = custom_prompt.strip() if custom_prompt and custom_prompt.strip() else sample_prompt
    try:
        # video_file is a local file path provided by the Gradio uploader.
        video_path, frames, timestamps = get_video_frames(video_file, num_frames=64)
    except Exception as e:
        return f"Error processing video: {str(e)}"
    try:
        output = inference(video_path, final_prompt)
    except Exception as e:
        return f"Error during inference: {str(e)}"
    return output
# -------------------------------------------------
# 6. Build the Gradio Interface
# -------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Video Chapter Splitter using Qwen 2.5 VL (3B) on CPU")
    gr.Markdown("Upload a video file and either type a custom prompt or select one of the sample prompts. Then click **Process Video** to generate the chapter breakdown.")
    with gr.Row():
        # The source parameter was removed; recent Gradio versions handle uploads by default.
        video_input = gr.Video(label="Upload Video")
    with gr.Row():
        custom_prompt_input = gr.Textbox(label="Custom Prompt", placeholder="Enter custom prompt (optional)...", lines=2)
    with gr.Row():
        sample_prompt_input = gr.Dropdown(label="Sample Prompts", choices=sample_prompts, value=sample_prompts[0])
    output_text = gr.Textbox(label="Output", lines=10)
    run_button = gr.Button("Process Video")
    run_button.click(fn=process_video, inputs=[video_input, custom_prompt_input, sample_prompt_input], outputs=output_text)

if __name__ == "__main__":
    demo.launch()
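# Note: demo.launch() with defaults is what Hugging Face Spaces expects. For local
# testing, standard Gradio options such as share=True (temporary public link) or
# server_name="0.0.0.0" (expose on the local network) could be passed instead.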