import gradio as gr
import os
from llama_cpp import Llama
from qdrant_client import QdrantClient
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import tempfile
import uuid
import re
import subprocess
import traceback
QDRANT_COLLECTION_NAME = "video_frames"
VIDEO_SEGMENT_DURATION = 40 # Extract 40 seconds around the timestamp
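
# Pipeline: embed the user query, retrieve matching lecture frames from
# Qdrant, let the fine-tuned LLM pick the single best segment, then cut a
# ~40s clip around its timestamp with FFmpeg for display in Gradio.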
# Load Secrets from Environment Variables
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
# Check for qdrant key
if not QDRANT_API_KEY:
print("Error: QDRANT_API_KEY environment variable not found.")
print("Please add your Qdrant API key as a secret named 'QDRANT_API_KEY' in your Hugging Face Space settings.")
raise ValueError("QDRANT_API_KEY environment variable not set.")
print("Initializing LLM...")
try:
    llm = Llama.from_pretrained(
        repo_id="m1tch/gemma-finetune-ai_class_gguf",
        filename="gemma-3_ai_class.Q8_0.gguf",
        n_gpu_layers=-1,
        n_ctx=2048,
        verbose=False
    )
    print("LLM initialized successfully.")
except Exception as e:
print(f"Error initializing LLM: {e}")
raise
print("Connecting to Qdrant...")
try:
    qdrant_client = QdrantClient(
        url="https://2c18d413-cbb5-441c-b060-4c8c2302dcde.us-east4-0.gcp.cloud.qdrant.io:6333/",
        api_key=QDRANT_API_KEY,
        timeout=60
    )
    qdrant_client.get_collections()
    print("Qdrant connection successful.")
except Exception as e:
print(f"Error connecting to Qdrant: {e}")
raise
print("Loading dataset stream...")
try:
    # Load video dataset
    dataset = load_dataset("aegean-ai/ai-lectures-spring-24", split="train", streaming=True)
    print("Dataset loaded.")
except Exception as e:
print(f"Error loading dataset: {e}")
raise
try:
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
    print("Sentence Transformer model loaded.")
except Exception as e:
    print(f"Error loading Sentence Transformer model: {e}")
    raise
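
# rag_query assumes each Qdrant point was indexed with a payload containing
# "video_id", "timestamp", "subtitle", and "frame_number" (the fields read
# back out of result.payload below).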
def rag_query(client, collection_name, query_text, top_k=5, filter_condition=None):
"""
Test RAG by querying the vector database with text. Returns a dictionary with search results and metadata.
Uses the pre-loaded embedding_model.
"""
try:
query_vector = embedding_model.encode(query_text).tolist()
search_params = {
"collection_name": collection_name,
"query_vector": query_vector,
"limit": top_k,
"with_payload": True,
"with_vectors": False
}
if filter_condition:
search_params["filter"] = filter_condition
search_results = client.search(**search_params)
formatted_results = []
for idx, result in enumerate(search_results):
formatted_results.append({
"rank": idx + 1,
"score": result.score,
"video_id": result.payload.get("video_id"),
"timestamp": result.payload.get("timestamp"),
"subtitle": result.payload.get("subtitle"),
"frame_number": result.payload.get("frame_number")
})
return {
"query": query_text,
"results": formatted_results,
"avg_score": sum(r.score for r in search_results) / len(search_results) if search_results else 0
}
except Exception as e:
print(f"Error during RAG query: {e}")
traceback.print_exc()
return {"error": str(e), "query": query_text, "results": []}
def extract_video_segment(video_id, start_time, duration, dataset):
"""
Extracts a single video segment file path from the dataset stream.
Returns a single path suitable for Gradio or None on failure.
"""
target_id = str(video_id)
target_key_pattern = re.compile(r"videos/" + re.escape(target_id) + r"/" + re.escape(target_id))
start_time = float(start_time)
duration = float(duration)
unique_id = str(uuid.uuid4())
temp_dir = os.path.join(tempfile.gettempdir(), f"gradio_video_seg_{unique_id}")
os.makedirs(temp_dir, exist_ok=True)
temp_video_path_full = os.path.join(temp_dir, f"{target_id}_full_{unique_id}.mp4")
output_path_ffmpeg = os.path.join(temp_dir, f"output_ffmpeg_{unique_id}.mp4")
print(f"Attempting to extract segment for video_id={target_id}, start={start_time:.2f}, duration={duration:.2f}")
print(f"Looking for dataset key matching pattern: {target_key_pattern.pattern}")
print(f"Temporary directory: {temp_dir}")
    found_sample = None
    max_search_attempts = 1000  # Upper bound on how many stream samples to scan
    print(f"Searching dataset stream for key matching pattern: {target_key_pattern.pattern}")
    dataset_iterator = iter(dataset)
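    # NOTE: a streaming dataset is re-iterated from the beginning on each call,
    # so extraction time grows with the video's position in the stream.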
    final_output_path = None  # defined before the try so the finally block can always reference it
    try:
        # Find and save the full video from the stream
        for i in range(max_search_attempts):
            try:
                sample = next(dataset_iterator)
                if '__key__' in sample and 'mp4' in sample and target_key_pattern.match(sample['__key__']):
                    print(f"Found video key {sample['__key__']} after {i+1} iterations. Saving to {temp_video_path_full}...")
                    with open(temp_video_path_full, 'wb') as f:
                        f.write(sample['mp4'])
                    print(f"Video saved successfully ({os.path.getsize(temp_video_path_full)} bytes).")
                    found_sample = sample
                    break  # Found the video
            except StopIteration:
                print("Reached end of dataset stream without finding the video within the search limit.")
                break
            except Exception as e:
                print(f"Warning: Error iterating dataset sample {i+1}: {e}")
        if not found_sample or not os.path.exists(temp_video_path_full) or os.path.getsize(temp_video_path_full) == 0:
            print(f"Could not find or save video with ID {target_id} from dataset stream.")
            return None
        # Process the saved video with FFmpeg
        try:
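            # -ss before -i does fast input seeking, and FFmpeg resets output
            # timestamps so the clip starts at the seek point; -t then bounds
            # the duration. H.264 baseline profile plus +faststart keeps the
            # result playable in the browser player Gradio uses.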
            cmd = [
                'ffmpeg',
                '-y',
                '-ss', str(start_time),
                '-i', temp_video_path_full,
                '-t', str(duration),
                '-c:v', 'libx264',
                '-profile:v', 'baseline',
                '-level', '3.0',
                '-preset', 'fast',
                '-pix_fmt', 'yuv420p',
                '-movflags', '+faststart',
                '-c:a', 'aac',
                '-b:a', '128k',
                output_path_ffmpeg
            ]
            print(f"Running FFmpeg command: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
            if result.returncode == 0 and os.path.exists(output_path_ffmpeg) and os.path.getsize(output_path_ffmpeg) > 0:
                print(f"FFmpeg processing successful. Output: {output_path_ffmpeg}")
                final_output_path = output_path_ffmpeg
            else:
                print(f"FFmpeg error (Return Code: {result.returncode}):")
                print(f"FFmpeg stdout:\n{result.stdout}")
                print(f"FFmpeg stderr:\n{result.stderr}")
                final_output_path = None
        except subprocess.TimeoutExpired:
            print("FFmpeg command timed out.")
            final_output_path = None
        except FileNotFoundError:
            print("Error: ffmpeg command not found. Make sure FFmpeg is installed.")
            final_output_path = None
        except Exception as e:
            print(f"An unexpected error occurred during FFmpeg processing: {e}")
            traceback.print_exc()
            final_output_path = None
    finally:
        # Clean up temporary files
        print(f"Cleaning up temporary directory: {temp_dir}")
        if os.path.exists(temp_video_path_full):
            try:
                os.remove(temp_video_path_full)
                print(f"Cleaned up temporary full video: {temp_video_path_full}")
            except Exception as e:
                print(f"Warning: Could not remove temporary file {temp_video_path_full}: {e}")
        # Clean up failed FFmpeg output if it exists and wasn't the final path
        if final_output_path != output_path_ffmpeg and os.path.exists(output_path_ffmpeg):
            try:
                os.remove(output_path_ffmpeg)
            except Exception as e:
                print(f"Warning: Could not remove failed ffmpeg output {output_path_ffmpeg}: {e}")
    # Return the path of the successfully created segment or None
    if final_output_path and os.path.exists(final_output_path):
        print(f"Returning video segment path: {final_output_path}")
        return final_output_path
    else:
        print("Video segment extraction failed.")
        return None
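
# parse_llm_output expects the bracketed fields requested in the prompt below, e.g.:
#   {Best Result: [12]}
#   {Timestamp: [123.5]}
#   {Content: [subtitle text]}
#   {Reasoning: [why this result answers the query]}
# and returns {'video_id': '12', 'timestamp': '123.5', 'content': ..., 'reasoning': ...}.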
def parse_llm_output(text):
"""
Parses the LLM's structured output using string manipulation.
"""
data = {}
print(f"\nDEBUG: Raw text input to parse_llm_output:\n---\n{text}\n---")
def extract_field(text, field_name):
start_marker_lower = "{" + field_name.lower() + ":"
start_index = text.lower().find(start_marker_lower)
if start_index != -1:
actual_marker_end = start_index + len(start_marker_lower)
end_index = text.find('}', actual_marker_end)
if end_index != -1:
value = text[actual_marker_end : end_index]
value = value.strip()
if value.startswith('[') and value.endswith(']'):
value = value[1:-1].strip()
value = value.strip('\'"“”')
return value.strip()
else:
print(f"Warning: Found '{{{field_name}:' marker but no closing '}}' found afterwards.")
else:
print(f"Warning: Marker '{{{field_name}:' not found in text.")
return None
# Extract fields
data['video_id'] = extract_field(text, 'Best Result')
data['timestamp'] = extract_field(text, 'Timestamp')
data['content'] = extract_field(text, 'Content')
data['reasoning'] = extract_field(text, 'Reasoning')
if data.get('timestamp'):
try:
float(data['timestamp'])
except ValueError:
print(f"Warning: Parsed timestamp '{data['timestamp']}' is not a valid number.")
data['timestamp'] = None
print(f"Parsed LLM output: {data}")
return data
def process_query_and_get_video(query_text):
"""
Orchestrates RAG, LLM query, parsing, and video extraction.
Returns only the video path or None.
"""
print(f"\n--- Processing query: '{query_text}' ---")
# RAG Query
print("Step 1: Performing RAG query...")
rag_results = rag_query(qdrant_client, QDRANT_COLLECTION_NAME, query_text)
if "error" in rag_results or not rag_results.get("results"):
error_msg = rag_results.get('error', 'No relevant segments found by RAG.')
print(f"RAG Error/No Results: {error_msg}")
# Return None for video output on RAG failure
return None
print(f"RAG query successful. Found {len(rag_results['results'])} results.")
# Format LLM Prompt
print("Step 2: Formatting prompt for LLM...")
results_for_llm = "\n".join([
f"Rank: {r['rank']}, Score: {r['score']:.4f}, Video ID: {r['video_id']}, Timestamp: {r['timestamp']}, Subtitle: {r['subtitle']}"
for r in rag_results['results']
])
prompt = f"""You are tasked with selecting the most relevant information from a set of video subtitle segments to answer a query.
QUERY: "{query_text}"
Here are the relevant video segments found:
---
{results_for_llm}
---
For each result provided, evaluate how well it directly addresses the definition or explanation related to the query. Pay attention to:
1. Clarity of explanation
2. Relevance to the query
3. Completeness of information
From the provided results, select the SINGLE BEST match that most directly answers the query.
Format your response STRICTLY as follows, with each field on a new line:
{{Best Result: [video_id]}}
{{Timestamp: [timestamp]}}
{{Content: [subtitle text from the selected result]}}
{{Reasoning: [Brief explanation of why this result best answers the query]}}
"""
    # Step 3: Call the LLM
    print("Step 3: Querying the LLM...")
    try:
        output = llm.create_chat_completion(
            messages=[
                {"role": "system", "content": "You are a helpful assistant designed to select the best video segment based on relevance to a query, following a specific output format."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.1,
            max_tokens=300
        )
        llm_response_text = output['choices'][0]['message']['content'].strip()
        print(f"LLM Response:\n---\n{llm_response_text}\n---")
    except Exception as e:
        print(f"Error during LLM call: {e}")
        traceback.print_exc()
        return None
    # Step 4: Parse the LLM response
    print("Step 4: Parsing LLM response...")
    parsed_data = parse_llm_output(llm_response_text)
    video_id = parsed_data.get('video_id')
    timestamp_str = parsed_data.get('timestamp')
    if not video_id or not timestamp_str:
        print("Error: Could not parse required video_id or timestamp from LLM response.")
        print(f"Raw LLM response that failed parsing:\n---\n{llm_response_text}\n---")  # Print raw output for debugging
        # Return None for video output on parsing failure
        return None
    try:
        timestamp = float(timestamp_str)
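        # Start the clip VIDEO_SEGMENT_DURATION / 4 (10s of the 40s window)
        # before the selected timestamp so a little context leads in.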
        start_time = max(0.0, timestamp - (VIDEO_SEGMENT_DURATION / 4.0))
        actual_duration = VIDEO_SEGMENT_DURATION
        print(f"Calculated segment start time: {start_time:.2f}s")
    except ValueError:
        print(f"Error: Could not convert parsed timestamp '{timestamp_str}' to float.")
        # Return None for video output on invalid timestamp
        return None
    # Step 5: Extract the video segment
    print(f"Step 5: Extracting video segment (ID: {video_id}, Start: {start_time:.2f}s, Duration: {actual_duration:.2f}s)...")
    video_path = extract_video_segment(video_id, start_time, actual_duration, dataset)
    if video_path and os.path.exists(video_path):
        print(f"Video segment extracted successfully: {video_path}")
        return video_path
    else:
        print("Failed to extract video segment.")
        return None
with gr.Blocks() as iface:
    gr.Markdown(
        """
        # AI Lecture Video Q&A
        Ask a question about the AI lectures. The system will find relevant segments using RAG,
        ask a fine-tuned LLM to select the best one, and display the corresponding video clip.
        """
    )
    with gr.Row():
        query_input = gr.Textbox(label="Your Question", placeholder="e.g., What is a convolutional neural network?")
        submit_button = gr.Button("Ask & Find Video")
    with gr.Row():
        video_output = gr.Video(label="Relevant Video Segment", format="mp4")
    submit_button.click(
        fn=process_query_and_get_video,
        inputs=query_input,
        outputs=video_output
    )
    gr.Examples(
        examples=[
            "Using only the videos, explain how ResNets work.",
            "Using only the videos, explain the advantages of CNNs over fully connected networks.",
            "Using only the videos, explain the binary cross entropy loss function.",
        ],
        inputs=query_input,
        outputs=video_output,
        fn=process_query_and_get_video,
        cache_examples=False,
    )
print("Launching Gradio interface...")
iface.launch(debug=True, share=False)