# rag-youtube-assistant/app/generate_ground_truth.py
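"""Generate ground-truth question sets for YouTube video transcripts.

Questions are generated from each video's transcript, stored through the
database handler, and mirrored to data/ground-truth-retrieval.csv for the
retrieval-evaluation step.
"""
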
import pandas as pd
import json
from tqdm import tqdm
import logging
import os
import sys
import re
import sqlite3
from transformers import pipeline  # used by generate_questions() below

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    stream=sys.stdout
)
logger = logging.getLogger(__name__)

def extract_model_name(index_name):
    match = re.search(r'video_[^_]+_(.+)$', index_name)
    if match:
        return match.group(1)
    return None

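# Illustrative example (index naming assumed from the regex above): an index
# called "video_dQw4w9WgXcQ_multi-qa-MiniLM-L6" would yield the model name
# "multi-qa-MiniLM-L6"; names that do not match "video_<id>_<model>" return None.
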
def get_transcript_from_sqlite(db_path, video_id):
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        cursor.execute("SELECT transcript_content FROM videos WHERE youtube_id = ?", (video_id,))
        result = cursor.fetchone()
        conn.close()
        if result:
            return result[0]
    except Exception as e:
        logger.error(f"Error retrieving transcript from SQLite: {str(e)}")
    return None

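# Note: the query above assumes the SQLite database has a `videos` table with
# `youtube_id` and `transcript_content` columns; missing rows return None.
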
def generate_questions(transcript, max_retries=3):
    prompt_template = """
You are an AI assistant tasked with generating questions based on a YouTube video transcript.
Formulate EXACTLY 10 questions that a user might ask based on the provided transcript.
Make the questions specific to the content of the transcript.
The questions should be complete and not too short. Use as few words as possible from the transcript.
Ensure that all 10 questions are unique and not repetitive.

The transcript:

{transcript}

Provide the output in parsable JSON without using code blocks:

{{"questions": ["question1", "question2", ..., "question10"]}}
""".strip()

    all_questions = set()
    retries = 0

    while len(all_questions) < 10 and retries < max_retries:
        try:
            # flan-t5 is a sequence-to-sequence model, so it needs the
            # "text2text-generation" pipeline rather than "text-generation".
            model = pipeline("text2text-generation", model="google/flan-t5-base", device=-1)
            response = model(prompt_template.format(transcript=transcript),
                             max_length=1024,
                             num_return_sequences=1)[0]['generated_text']
            questions = json.loads(response)['questions']
            all_questions.update(questions)
        except Exception as e:
            logger.error(f"Error generating questions: {str(e)}")
        retries += 1

    if len(all_questions) < 10:
        logger.warning(f"Could only generate {len(all_questions)} unique questions after {max_retries} attempts.")

    return {"questions": list(all_questions)[:10]}

def generate_ground_truth(db_handler, data_processor, video_id):
    # Get existing questions for this video to avoid duplicates
    existing_questions = set(q[1] for q in db_handler.get_ground_truth_by_video(video_id))

    # Get transcript from SQLite
    transcript = get_transcript_from_sqlite(db_handler.db_path, video_id)
    if not transcript:
        logger.error(f"Failed to retrieve transcript for video {video_id}")
        return None

    # Generate questions until we have 10 unique ones
    all_questions = set()
    max_attempts = 3
    attempts = 0

    while len(all_questions) < 10 and attempts < max_attempts:
        questions = generate_questions(transcript)
        if questions and 'questions' in questions:
            new_questions = set(questions['questions']) - existing_questions
            all_questions.update(new_questions)
        attempts += 1

    if not all_questions:
        logger.error("Failed to generate any unique questions.")
        return None

    # Store questions in database
    db_handler.add_ground_truth_questions(video_id, all_questions)

    # Create DataFrame and save to CSV
    df = pd.DataFrame([(video_id, q) for q in all_questions], columns=['video_id', 'question'])
    csv_path = 'data/ground-truth-retrieval.csv'

    # Append to existing CSV if it exists, otherwise create new
    if os.path.exists(csv_path):
        df.to_csv(csv_path, mode='a', header=False, index=False)
    else:
        df.to_csv(csv_path, index=False)

    logger.info(f"Ground truth data saved to {csv_path}")
    return df

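# Note: to_csv() does not create intermediate directories, so the data/ folder
# must already exist; callers may want os.makedirs('data', exist_ok=True) first.
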
def get_evaluation_display_data(video_id=None):
    """Get evaluation data from CSV file"""
    try:
        csv_df = pd.read_csv('data/evaluation_results.csv')
        if video_id:
            csv_df = csv_df[csv_df['video_id'] == video_id]
        return csv_df
    except FileNotFoundError:
        return pd.DataFrame()

def get_ground_truth_display_data(db_handler, video_id=None, channel_name=None):
    """Get ground truth data from both database and CSV file"""
    # Try to get data from database first
    if video_id:
        data = db_handler.get_ground_truth_by_video(video_id)
    elif channel_name:
        data = db_handler.get_ground_truth_by_channel(channel_name)
    else:
        data = []

    # Create DataFrame from database data
    if data:
        db_df = pd.DataFrame(data, columns=['id', 'video_id', 'question', 'generation_date', 'channel_name'])
    else:
        db_df = pd.DataFrame()

    # Try to get data from CSV
    try:
        csv_df = pd.read_csv('data/ground-truth-retrieval.csv')
        if video_id:
            csv_df = csv_df[csv_df['video_id'] == video_id]
        elif channel_name:
            # Join with videos table to get channel information
            videos_df = pd.DataFrame(db_handler.get_all_videos(),
                                     columns=['youtube_id', 'title', 'channel_name', 'upload_date'])
            csv_df = csv_df.merge(videos_df, left_on='video_id', right_on='youtube_id')
            csv_df = csv_df[csv_df['channel_name'] == channel_name]
    except FileNotFoundError:
        csv_df = pd.DataFrame()

    # Combine data from both sources
    if not db_df.empty and not csv_df.empty:
        combined_df = pd.concat([db_df, csv_df]).drop_duplicates(subset=['video_id', 'question'])
    elif not db_df.empty:
        combined_df = db_df
    elif not csv_df.empty:
        combined_df = csv_df
    else:
        combined_df = pd.DataFrame()

    return combined_df

def generate_ground_truth_for_all_videos(db_handler, data_processor):
    videos = db_handler.get_all_videos()
    all_questions = []

    for video in tqdm(videos, desc="Generating ground truth"):
        video_id = video[0]  # Assuming the video ID is the first element in the tuple
        df = generate_ground_truth(db_handler, data_processor, video_id)
        if df is not None:
            all_questions.extend(df.values.tolist())

    if all_questions:
        df = pd.DataFrame(all_questions, columns=['video_id', 'question'])
        csv_path = 'data/ground-truth-retrieval.csv'
        df.to_csv(csv_path, index=False)
        logger.info(f"Ground truth data for all videos saved to {csv_path}")
        return df
    else:
        logger.error("Failed to generate questions for any video.")
        return None
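

# Minimal usage sketch (not part of the original module). The handler classes
# below are assumed names for whatever objects the app normally injects into
# these functions; replace the imports with the real ones from this project.
if __name__ == "__main__":
    from database import DatabaseHandler        # assumed module/class name
    from data_processor import DataProcessor    # assumed module/class name

    db_handler = DatabaseHandler()
    data_processor = DataProcessor()
    generate_ground_truth_for_all_videos(db_handler, data_processor)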