import gradio as gr
import os
from llama_cpp import Llama
from qdrant_client import QdrantClient
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import tempfile
import uuid
import re
import subprocess
import traceback

QDRANT_COLLECTION_NAME = "video_frames"
VIDEO_SEGMENT_DURATION = 40  # Extract a 40-second clip around the matched timestamp
# Load secrets from environment variables
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")

# Check for the Qdrant key
if not QDRANT_API_KEY:
    print("Error: QDRANT_API_KEY environment variable not found.")
    print("Please add your Qdrant API key as a secret named 'QDRANT_API_KEY' in your Hugging Face Space settings.")
    raise ValueError("QDRANT_API_KEY environment variable not set.")
print("Initializing LLM...") | |
try: | |
llm = Llama.from_pretrained( | |
repo_id="m1tch/gemma-finetune-ai_class_gguf", | |
filename="gemma-3_ai_class.Q8_0.gguf", | |
n_gpu_layers=-1, | |
n_ctx=2048, | |
verbose=False | |
) | |
print("LLM initialized successfully.") | |
except Exception as e: | |
print(f"Error initializing LLM: {e}") | |
raise | |
print("Connecting to Qdrant...") | |
try: | |
qdrant_client = QdrantClient( | |
url="https://2c18d413-cbb5-441c-b060-4c8c2302dcde.us-east4-0.gcp.cloud.qdrant.io:6333/", | |
api_key=QDRANT_API_KEY, | |
timeout=60 | |
) | |
qdrant_client.get_collections() | |
print("Qdrant connection successful.") | |
except Exception as e: | |
print(f"Error connecting to Qdrant: {e}") | |
raise | |
print("Loading dataset stream...") | |
try: | |
# Load video dataset | |
dataset = load_dataset("aegean-ai/ai-lectures-spring-24", split="train", streaming=True) | |
print(f"Dataset loaded.") | |
except Exception as e: | |
print(f"Error loading dataset: {e}") | |
raise | |
print("Loading embedding model...")
try:
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
    print("Sentence Transformer model loaded.")
except Exception as e:
    print(f"Error loading Sentence Transformer model: {e}")
    raise
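# Note: all-MiniLM-L6-v2 encodes text into 384-dimensional vectors, so the
# "video_frames" collection must have been created with a matching vector
# size (and, presumably, cosine distance) for the searches below to score
# sensibly.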
def rag_query(client, collection_name, query_text, top_k=5, filter_condition=None):
    """
    Query the vector database with text and return a dictionary with the
    search results and metadata. Uses the pre-loaded embedding_model.
    """
    try:
        query_vector = embedding_model.encode(query_text).tolist()
        search_params = {
            "collection_name": collection_name,
            "query_vector": query_vector,
            "limit": top_k,
            "with_payload": True,
            "with_vectors": False
        }
        if filter_condition:
            # QdrantClient.search() expects the filter under "query_filter"
            search_params["query_filter"] = filter_condition
        search_results = client.search(**search_params)

        formatted_results = []
        for idx, result in enumerate(search_results):
            formatted_results.append({
                "rank": idx + 1,
                "score": result.score,
                "video_id": result.payload.get("video_id"),
                "timestamp": result.payload.get("timestamp"),
                "subtitle": result.payload.get("subtitle"),
                "frame_number": result.payload.get("frame_number")
            })

        return {
            "query": query_text,
            "results": formatted_results,
            "avg_score": sum(r.score for r in search_results) / len(search_results) if search_results else 0
        }
    except Exception as e:
        print(f"Error during RAG query: {e}")
        traceback.print_exc()
        return {"error": str(e), "query": query_text, "results": []}
def extract_video_segment(video_id, start_time, duration, dataset):
    """
    Extracts a single video segment from the dataset stream.
    Returns a local file path suitable for Gradio, or None on failure.
    """
    target_id = str(video_id)
    target_key_pattern = re.compile(r"videos/" + re.escape(target_id) + r"/" + re.escape(target_id))
    start_time = float(start_time)
    duration = float(duration)

    unique_id = str(uuid.uuid4())
    temp_dir = os.path.join(tempfile.gettempdir(), f"gradio_video_seg_{unique_id}")
    os.makedirs(temp_dir, exist_ok=True)
    temp_video_path_full = os.path.join(temp_dir, f"{target_id}_full_{unique_id}.mp4")
    output_path_ffmpeg = os.path.join(temp_dir, f"output_ffmpeg_{unique_id}.mp4")

    print(f"Attempting to extract segment for video_id={target_id}, start={start_time:.2f}, duration={duration:.2f}")
    print(f"Temporary directory: {temp_dir}")

    found_sample = None
    final_output_path = None  # Initialized up front so the finally block can always reference it
    max_search_attempts = 1000  # Limit how far we scan the stream
    print(f"Searching dataset stream for key matching pattern: {target_key_pattern.pattern}")
    # iter() restarts the streaming dataset from the beginning on each call
    dataset_iterator = iter(dataset)

    try:
        # Find and save the full video from the stream
        for i in range(max_search_attempts):
            try:
                sample = next(dataset_iterator)
                if '__key__' in sample and 'mp4' in sample and target_key_pattern.match(sample['__key__']):
                    print(f"Found video key {sample['__key__']} after {i+1} iterations. Saving to {temp_video_path_full}...")
                    with open(temp_video_path_full, 'wb') as f:
                        f.write(sample['mp4'])
                    print(f"Video saved successfully ({os.path.getsize(temp_video_path_full)} bytes).")
                    found_sample = sample
                    break  # Found the video
            except StopIteration:
                print("Reached end of dataset stream without finding the video within the search limit.")
                break
            except Exception as e:
                print(f"Warning: Error iterating dataset sample {i+1}: {e}")

        if not found_sample or not os.path.exists(temp_video_path_full) or os.path.getsize(temp_video_path_full) == 0:
            print(f"Could not find or save video with ID {target_id} from the dataset stream.")
            return None
        # Process the saved video with FFmpeg
        try:
            # -ss before -i does a fast input seek (output timestamps restart
            # at 0), and -t caps the clip length; re-encoding to H.264
            # baseline + AAC with +faststart keeps the clip browser-playable.
            cmd = [
                'ffmpeg',
                '-y',
                '-ss', str(start_time),
                '-i', temp_video_path_full,
                '-t', str(duration),
                '-c:v', 'libx264',
                '-profile:v', 'baseline',
                '-level', '3.0',
                '-preset', 'fast',
                '-pix_fmt', 'yuv420p',
                '-movflags', '+faststart',
                '-c:a', 'aac',
                '-b:a', '128k',
                output_path_ffmpeg
            ]
            print(f"Running FFmpeg command: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)

            if result.returncode == 0 and os.path.exists(output_path_ffmpeg) and os.path.getsize(output_path_ffmpeg) > 0:
                print(f"FFmpeg processing successful. Output: {output_path_ffmpeg}")
                final_output_path = output_path_ffmpeg
            else:
                print(f"FFmpeg error (Return Code: {result.returncode}):")
                print(f"FFmpeg stdout:\n{result.stdout}")
                print(f"FFmpeg stderr:\n{result.stderr}")
                print("FFmpeg failed.")
                final_output_path = None
        except subprocess.TimeoutExpired:
            print("FFmpeg command timed out.")
            final_output_path = None
        except FileNotFoundError:
            print("Error: ffmpeg command not found. Make sure FFmpeg is installed.")
            final_output_path = None
        except Exception as e:
            print(f"An unexpected error occurred during FFmpeg processing: {e}")
            traceback.print_exc()
            final_output_path = None
    finally:
        # Clean up the full download; the extracted segment is intentionally
        # kept on disk so Gradio can serve it.
        print(f"Cleaning up temporary files in: {temp_dir}")
        if os.path.exists(temp_video_path_full):
            try:
                os.remove(temp_video_path_full)
                print(f"Cleaned up temporary full video: {temp_video_path_full}")
            except Exception as e:
                print(f"Warning: Could not remove temporary file {temp_video_path_full}: {e}")
        # Clean up a failed FFmpeg output if it exists and isn't the final path
        if final_output_path != output_path_ffmpeg and os.path.exists(output_path_ffmpeg):
            try:
                os.remove(output_path_ffmpeg)
            except Exception as e:
                print(f"Warning: Could not remove failed ffmpeg output {output_path_ffmpeg}: {e}")

    # Return the path of the successfully created segment, or None
    if final_output_path and os.path.exists(final_output_path):
        print(f"Returning video segment path: {final_output_path}")
        return final_output_path
    else:
        print("Video segment extraction failed.")
        return None
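# A minimal usage sketch (the video id and timestamp are placeholders; ffmpeg
# must be on PATH):
#
#   clip_path = extract_video_segment(3, start_time=120.0, duration=VIDEO_SEGMENT_DURATION, dataset=dataset)
#   if clip_path:
#       print(f"Clip written to {clip_path}")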
def parse_llm_output(text):
    """
    Parses the LLM's structured output using string manipulation.
    """
    data = {}
    print(f"\nDEBUG: Raw text input to parse_llm_output:\n---\n{text}\n---")

    def extract_field(text, field_name):
        # Look for a "{Field Name:" marker (case-insensitive) and capture
        # everything up to the next "}".
        start_marker_lower = "{" + field_name.lower() + ":"
        start_index = text.lower().find(start_marker_lower)
        if start_index != -1:
            actual_marker_end = start_index + len(start_marker_lower)
            end_index = text.find('}', actual_marker_end)
            if end_index != -1:
                value = text[actual_marker_end:end_index].strip()
                # Strip optional [brackets] and stray quote characters
                if value.startswith('[') and value.endswith(']'):
                    value = value[1:-1].strip()
                value = value.strip('\'"“”')
                return value.strip()
            else:
                print(f"Warning: Found '{{{field_name}:' marker but no closing '}}' found afterwards.")
        else:
            print(f"Warning: Marker '{{{field_name}:' not found in text.")
        return None

    # Extract fields
    data['video_id'] = extract_field(text, 'Best Result')
    data['timestamp'] = extract_field(text, 'Timestamp')
    data['content'] = extract_field(text, 'Content')
    data['reasoning'] = extract_field(text, 'Reasoning')

    # The timestamp must be numeric for the segment math downstream
    if data.get('timestamp'):
        try:
            float(data['timestamp'])
        except ValueError:
            print(f"Warning: Parsed timestamp '{data['timestamp']}' is not a valid number.")
            data['timestamp'] = None

    print(f"Parsed LLM output: {data}")
    return data
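# Example: given a response such as
#   {Best Result: 7}
#   {Timestamp: 1234.5}
#   {Content: [ResNets add skip connections ...]}
#   {Reasoning: [Directly defines the concept.]}
# this returns
#   {'video_id': '7', 'timestamp': '1234.5',
#    'content': 'ResNets add skip connections ...',
#    'reasoning': 'Directly defines the concept.'}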
def process_query_and_get_video(query_text):
    """
    Orchestrates the RAG query, LLM selection, parsing, and video extraction.
    Returns the extracted video path, or None on failure.
    """
    print(f"\n--- Processing query: '{query_text}' ---")

    # Step 1: Perform the RAG query
    print("Step 1: Performing RAG query...")
    rag_results = rag_query(qdrant_client, QDRANT_COLLECTION_NAME, query_text)
    if "error" in rag_results or not rag_results.get("results"):
        error_msg = rag_results.get('error', 'No relevant segments found by RAG.')
        print(f"RAG Error/No Results: {error_msg}")
        return None  # No video output on RAG failure
    print(f"RAG query successful. Found {len(rag_results['results'])} results.")

    # Step 2: Format the LLM prompt
    print("Step 2: Formatting prompt for LLM...")
    results_for_llm = "\n".join([
        f"Rank: {r['rank']}, Score: {r['score']:.4f}, Video ID: {r['video_id']}, Timestamp: {r['timestamp']}, Subtitle: {r['subtitle']}"
        for r in rag_results['results']
    ])
prompt = f"""You are tasked with selecting the most relevant information from a set of video subtitle segments to answer a query. | |
QUERY: "{query_text}" | |
Here are the relevant video segments found: | |
--- | |
{results_for_llm} | |
--- | |
For each result provided, evaluate how well it directly addresses the definition or explanation related to the query. Pay attention to: | |
1. Clarity of explanation | |
2. Relevance to the query | |
3. Completeness of information | |
From the provided results, select the SINGLE BEST match that most directly answers the query. | |
Format your response STRICTLY as follows, with each field on a new line: | |
{{Best Result: [video_id]}} | |
{{Timestamp: [timestamp]}} | |
{{Content: [subtitle text from the selected result]}} | |
{{Reasoning: [Brief explanation of why this result best answers the query]}} | |
""" | |
    # Step 3: Call the LLM
    print("Step 3: Querying the LLM...")
    try:
        output = llm.create_chat_completion(
            messages=[
                {"role": "system", "content": "You are a helpful assistant designed to select the best video segment based on relevance to a query, following a specific output format."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.1,
            max_tokens=300
        )
        llm_response_text = output['choices'][0]['message']['content'].strip()
        print(f"LLM Response:\n---\n{llm_response_text}\n---")
    except Exception as e:
        print(f"Error during LLM call: {e}")
        traceback.print_exc()
        return None
    # Step 4: Parse the LLM response
    print("Step 4: Parsing LLM response...")
    parsed_data = parse_llm_output(llm_response_text)
    video_id = parsed_data.get('video_id')
    timestamp_str = parsed_data.get('timestamp')
    if not video_id or not timestamp_str:
        print("Error: Could not parse required video_id or timestamp from LLM response.")
        print(f"Raw LLM response that failed parsing:\n---\n{llm_response_text}\n---")  # Print raw output for debugging
        return None  # No video output on parsing failure

    try:
        timestamp = float(timestamp_str)
        # Start the clip a quarter of the segment duration (10 s) before the
        # matched timestamp so some context precedes the answer
        start_time = max(0.0, timestamp - (VIDEO_SEGMENT_DURATION / 4.0))
        actual_duration = VIDEO_SEGMENT_DURATION
        print(f"Calculated segment start time: {start_time:.2f}s")
    except ValueError:
        print(f"Error: Could not convert parsed timestamp '{timestamp_str}' to float.")
        return None  # No video output on an invalid timestamp
    # Step 5: Extract the video segment
    print(f"Step 5: Extracting video segment (ID: {video_id}, Start: {start_time:.2f}s, Duration: {actual_duration:.2f}s)...")
    video_path = extract_video_segment(video_id, start_time, actual_duration, dataset)
    if video_path and os.path.exists(video_path):
        print(f"Video segment extracted successfully: {video_path}")
        return video_path
    else:
        print("Failed to extract video segment.")
        return None
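# A direct-call sketch, bypassing the UI (handy as a smoke test):
#
#   path = process_query_and_get_video("What is gradient descent?")
#   print(path)  # local .mp4 path, or None if any stage failed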
with gr.Blocks() as iface:
    gr.Markdown(
        """
        # AI Lecture Video Q&A
        Ask a question about the AI lectures. The system will find relevant segments using RAG,
        ask a fine-tuned LLM to select the best one, and display the corresponding video clip.
        """
    )
    with gr.Row():
        query_input = gr.Textbox(label="Your Question", placeholder="e.g., What is a convolutional neural network?")
        submit_button = gr.Button("Ask & Find Video")
    with gr.Row():
        video_output = gr.Video(label="Relevant Video Segment", format="mp4")

    submit_button.click(
        fn=process_query_and_get_video,
        inputs=query_input,
        outputs=video_output
    )

    gr.Examples(
        examples=[
            "Using only the videos, explain how ResNets work.",
            "Using only the videos, explain the advantages of CNNs over fully connected networks.",
            "Using only the videos, explain the binary cross entropy loss function.",
        ],
        inputs=query_input,
        outputs=video_output,
        fn=process_query_and_get_video,
        cache_examples=False,
    )
print("Launching Gradio interface...") | |
iface.launch(debug=True, share=False) |