Spaces:
Running
Running
File size: 6,791 Bytes
dbd33b2 a61b32e 982a34d e2cbf8c a61b32e e2cbf8c 6174942 a61b32e dbd33b2 66a5452 dbd33b2 66a5452 dbd33b2 66a5452 dbd33b2 66a5452 dbd33b2 66a5452 e2cbf8c 66a5452 25b2b2b 66a5452 a61b32e e2cbf8c a61b32e dbd33b2 66a5452 dbd33b2 66a5452 25b2b2b 66a5452 e2cbf8c 66a5452 25b2b2b 66a5452 25b2b2b a61b32e 25b2b2b a61b32e 25b2b2b dbd33b2 a61b32e e2cbf8c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import pandas as pd
import json
from tqdm import tqdm
import logging
import os
import sys
import re
import sqlite3
# Logging setup: records at INFO level and above go to stdout, each line
# carrying a timestamp, the logger name, the level and the message.
logging.basicConfig(
    stream=sys.stdout,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by the functions below.
logger = logging.getLogger(__name__)
def extract_model_name(index_name):
    """Return the model part of an index name, or None if it cannot be found.

    The name is searched for ``video_<segment>_<model>``, where ``<segment>``
    is a run of characters containing no underscore. Everything after the
    underscore that follows ``<segment>`` (up to the end of the string) is
    returned, so the model part may itself contain underscores.
    """
    found = re.search(r'video_[^_]+_(.+)$', index_name)
    return found.group(1) if found else None
def get_transcript_from_sqlite(db_path, video_id):
    """Fetch the stored transcript text for one video from a SQLite database.

    Args:
        db_path: Path of the SQLite database file to open.
        video_id: Value matched against the ``youtube_id`` column of the
            ``videos`` table.

    Returns:
        The ``transcript_content`` value of the first matching row, or None
        when no row matches or when any error occurs (errors are logged, not
        raised).
    """
    try:
        conn = sqlite3.connect(db_path)
        try:
            row = conn.execute(
                "SELECT transcript_content FROM videos WHERE youtube_id = ?",
                (video_id,),
            ).fetchone()
        finally:
            # Close even when the query raises, so a failed lookup does not
            # leak the connection.
            conn.close()
    except Exception as e:
        # Best-effort lookup: callers treat None as "no transcript available".
        logger.error(f"Error retrieving transcript from SQLite: {str(e)}")
        return None
    return row[0] if row else None
def generate_questions(transcript, max_retries=3):
    """Ask a text-generation model for up to 10 unique questions about a transcript.

    The model is prompted to answer with JSON of the form
    ``{"questions": [...]}``. Unique questions are accumulated across
    attempts until 10 have been collected or ``max_retries`` attempts have
    been made; every attempt counts, whether it succeeds or fails.

    Args:
        transcript: Transcript text interpolated into the prompt.
        max_retries: Maximum number of generation attempts.

    Returns:
        dict: ``{"questions": [...]}`` holding at most 10 unique questions.
        The list is shorter (possibly empty) when the model or the JSON
        parsing keeps failing; failures are logged, never raised.
    """
    prompt_template = """
You are an AI assistant tasked with generating questions based on a YouTube video transcript.
Formulate EXACTLY 10 questions that a user might ask based on the provided transcript.
Make the questions specific to the content of the transcript.
The questions should be complete and not too short. Use as few words as possible from the transcript.
Ensure that all 10 questions are unique and not repetitive.
The transcript:
{transcript}
Provide the output in parsable JSON without using code blocks:
{{"questions": ["question1", "question2", ..., "question10"]}}
""".strip()
    # The prompt does not change between attempts, so build it once.
    prompt = prompt_template.format(transcript=transcript)

    all_questions = set()
    # Built lazily inside the loop: a failed load is retried on the next
    # attempt, while a successful load is reused instead of being rebuilt.
    model = None

    for _ in range(max_retries):
        if len(all_questions) >= 10:
            break
        try:
            if model is None:
                # NOTE(review): `pipeline` is not imported in the visible part
                # of this file (presumably transformers.pipeline) -- confirm.
                # NOTE(review): flan-t5 is an encoder-decoder model; confirm
                # that "text-generation" is the right pipeline task for it.
                model = pipeline("text-generation", model="google/flan-t5-base", device=-1)
            response = model(prompt,
                             max_length=1024,
                             num_return_sequences=1)[0]['generated_text']
            questions = json.loads(response)['questions']
            if not isinstance(questions, list):
                # A bare string would otherwise be split into single
                # characters by set.update().
                raise ValueError("'questions' must be a JSON list")
            all_questions.update(
                q.strip() for q in questions if isinstance(q, str) and q.strip()
            )
        except Exception as e:
            logger.error(f"Error generating questions: {str(e)}")

    if len(all_questions) < 10:
        logger.warning(f"Could only generate {len(all_questions)} unique questions after {max_retries} attempts.")
    return {"questions": list(all_questions)[:10]}
def generate_ground_truth(db_handler, data_processor, video_id):
    """Generate ground-truth questions for one video and persist them.

    New questions (those not already stored for the video) are written to the
    database through ``db_handler`` and appended to
    ``data/ground-truth-retrieval.csv``.

    Args:
        db_handler: Object providing ``get_ground_truth_by_video``,
            ``add_ground_truth_questions`` and a ``db_path`` attribute.
        data_processor: Not used by this function; kept for interface
            compatibility with callers.
        video_id: Identifier of the video to generate questions for.

    Returns:
        A DataFrame with columns ``video_id`` and ``question`` holding the
        newly stored questions, or None when the transcript cannot be read or
        no new question could be generated.
    """
    # Rows returned by get_ground_truth_by_video are laid out as
    # (id, video_id, question, generation_date, channel_name) -- the layout
    # get_ground_truth_display_data relies on -- so the question text sits at
    # index 2. Index 1 is the video id and would never filter anything.
    existing_questions = {row[2] for row in db_handler.get_ground_truth_by_video(video_id)}

    # Get transcript from SQLite
    transcript = get_transcript_from_sqlite(db_handler.db_path, video_id)
    if not transcript:
        logger.error(f"Failed to retrieve transcript for video {video_id}")
        return None

    # Accumulate new questions over a bounded number of generation rounds.
    all_questions = set()
    max_attempts = 3
    for _ in range(max_attempts):
        if len(all_questions) >= 10:
            break
        questions = generate_questions(transcript)
        if questions and 'questions' in questions:
            all_questions.update(set(questions['questions']) - existing_questions)

    if not all_questions:
        logger.error("Failed to generate any unique questions.")
        return None

    # Store questions in database
    db_handler.add_ground_truth_questions(video_id, all_questions)

    # Create DataFrame and save to CSV
    df = pd.DataFrame([(video_id, q) for q in all_questions], columns=['video_id', 'question'])
    csv_path = 'data/ground-truth-retrieval.csv'
    # Make sure the target directory exists so the first write cannot fail
    # on a missing 'data/' folder.
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)
    # Append to the existing CSV; write the header only when creating it.
    write_header = not os.path.exists(csv_path)
    df.to_csv(csv_path, mode='a', header=write_header, index=False)
    logger.info(f"Ground truth data saved to {csv_path}")
    return df
def get_evaluation_display_data(video_id=None):
    """Load evaluation results from ``data/evaluation_results.csv``.

    When ``video_id`` is truthy, only rows whose ``video_id`` column equals it
    are returned. A missing file yields an empty DataFrame.
    """
    try:
        results = pd.read_csv('data/evaluation_results.csv')
    except FileNotFoundError:
        return pd.DataFrame()

    if not video_id:
        return results

    matches_video = results['video_id'] == video_id
    return results[matches_video]
def get_ground_truth_display_data(db_handler, video_id=None, channel_name=None):
    """Collect ground-truth questions from the database and the CSV file.

    ``video_id`` takes precedence over ``channel_name`` as the filter. With
    neither given, the database is not queried and the whole CSV is used.
    When both sources yield rows they are concatenated (database rows first)
    and de-duplicated on ``(video_id, question)``.

    Returns a DataFrame, empty when neither source has matching rows.
    """
    # --- database side ---------------------------------------------------
    if video_id:
        rows = db_handler.get_ground_truth_by_video(video_id)
    elif channel_name:
        rows = db_handler.get_ground_truth_by_channel(channel_name)
    else:
        rows = []

    from_db = (
        pd.DataFrame(rows, columns=['id', 'video_id', 'question', 'generation_date', 'channel_name'])
        if rows
        else pd.DataFrame()
    )

    # --- CSV side ----------------------------------------------------------
    try:
        from_csv = pd.read_csv('data/ground-truth-retrieval.csv')
        if video_id:
            from_csv = from_csv[from_csv['video_id'] == video_id]
        elif channel_name:
            # The CSV has no channel column, so bring it in from the videos
            # listing before filtering on it.
            video_rows = db_handler.get_all_videos()
            videos = pd.DataFrame(
                video_rows,
                columns=['youtube_id', 'title', 'channel_name', 'upload_date'],
            )
            from_csv = from_csv.merge(videos, left_on='video_id', right_on='youtube_id')
            from_csv = from_csv[from_csv['channel_name'] == channel_name]
    except FileNotFoundError:
        from_csv = pd.DataFrame()

    # --- combine -------------------------------------------------------------
    has_db_rows = not from_db.empty
    has_csv_rows = not from_csv.empty

    if has_db_rows and has_csv_rows:
        stacked = pd.concat([from_db, from_csv])
        return stacked.drop_duplicates(subset=['video_id', 'question'])
    if has_db_rows:
        return from_db
    if has_csv_rows:
        return from_csv
    return pd.DataFrame()
def generate_ground_truth_for_all_videos(db_handler, data_processor):
    """Generate ground-truth questions for every video known to ``db_handler``.

    Each video is handled by ``generate_ground_truth``, which stores its
    questions in the database and appends them to
    ``data/ground-truth-retrieval.csv``.

    Args:
        db_handler: Object providing ``get_all_videos`` plus whatever
            ``generate_ground_truth`` needs.
        data_processor: Passed through unchanged to ``generate_ground_truth``.

    Returns:
        A DataFrame with columns ``video_id`` and ``question`` covering all
        questions generated in this run, or None when no video produced any.
    """
    videos = db_handler.get_all_videos()
    all_questions = []
    for video in tqdm(videos, desc="Generating ground truth"):
        video_id = video[0]  # Assuming the video ID is the first element in the tuple
        df = generate_ground_truth(db_handler, data_processor, video_id)
        if df is not None:
            all_questions.extend(df.values.tolist())

    if not all_questions:
        logger.error("Failed to generate questions for any video.")
        return None

    df = pd.DataFrame(all_questions, columns=['video_id', 'question'])
    csv_path = 'data/ground-truth-retrieval.csv'
    # generate_ground_truth has already appended every one of these rows to
    # csv_path. Rewriting the file here would truncate it to this run's rows
    # and discard rows appended by earlier runs, so the file is left as is.
    logger.info(f"Ground truth data for all videos saved to {csv_path}")
    return df